Skip to content

Commit 3796629

Browse files
authored
Put wer tag entity type in SBS output (#32)
* put wer tag entity type in sbs * add example in doc * Input flag documentation * fix link * fix header, unused code * test headers
1 parent 2456389 commit 3796629

File tree

11 files changed

+167
-89
lines changed

11 files changed

+167
-89
lines changed

docs/Advanced-Usage.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ Normalizations are a similar concept to synonyms. They allow a token or group of
7272
}
7373
```
7474

75+
### WER Sidecar
76+
77+
CLI flag: `--wer-sidecar`
78+
79+
Only usable for NLP format reference files. This passes a [WER sidecar](https://github.com/revdotcom/fstalign/blob/develop/docs//NLP-Format.md#wer-tag-sidecar) file to
80+
add extra information to some outputs. Optional.
81+
7582
## Outputs
7683

7784
### Text Log

docs/NLP-Format.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,19 @@ first|0||||LC|['6:DATE']|['6']
2525
quarter|0||||LC|['6:DATE']|['6']
2626
2020|0||||CA|['0:YEAR']|['0', '1', '6']
2727
NexGEn|0||||MC|['7:ORG']|['7']
28-
```
28+
```
29+
30+
## WER tag sidecar
31+
32+
WER tag sidecar files contain accompanying info for tokens in an NLP file. The
33+
keys are IDs corresponding to tokens in the NLP file `wer_tags` column. The
34+
objects under the keys are information about the token.
35+
36+
Example:
37+
```
38+
{
39+
'0': {'entity_type': 'YEAR'},
40+
'1': {'entity_type': 'CARDINAL'},
41+
'6': {'entity_type': 'SPACY>TIME'},
42+
}
43+
```

src/Nlp.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,36 @@
1515
/***********************************
1616
NLP FstLoader class start
1717
************************************/
18-
NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization)
19-
: NlpFstLoader(records, normalization, true) {}
18+
NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization,
19+
Json::Value wer_sidecar)
20+
: NlpFstLoader(records, normalization, wer_sidecar, true) {}
2021

21-
NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, bool processLabels)
22+
NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization,
23+
Json::Value wer_sidecar, bool processLabels)
2224
: FstLoader() {
2325
mNlpRows = records;
2426
mJsonNorm = normalization;
27+
mWerSidecar = wer_sidecar;
2528
std::string last_label;
2629
bool firstTk = true;
2730

31+
2832
// fuse multiple rows that have the same id/label into one entry only
2933
for (auto &row : mNlpRows) {
3034
auto curr_tk = row.token;
3135
auto curr_label = row.best_label;
3236
auto curr_label_id = row.best_label_id;
37+
auto curr_row_tags = row.wer_tags;
38+
// Update wer tags in records to real string labels
39+
vector<string> real_wer_tags;
40+
for (auto &tag: curr_row_tags) {
41+
auto real_tag = tag;
42+
if (mWerSidecar != Json::nullValue) {
43+
real_tag = "###"+ real_tag + "_" + mWerSidecar[real_tag]["entity_type"].asString() + "###";
44+
}
45+
real_wer_tags.push_back(real_tag);
46+
}
47+
row.wer_tags = real_wer_tags;
3348
std::string speaker = row.speakerId;
3449

3550
if (processLabels && curr_label != "") {

src/Nlp.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ class NlpReader {
4242

4343
class NlpFstLoader : public FstLoader {
4444
public:
45-
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, bool processLabels);
46-
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization);
45+
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar, bool processLabels);
46+
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar);
4747
virtual ~NlpFstLoader();
4848
virtual void addToSymbolTable(fst::SymbolTable &symbol) const;
4949
virtual fst::StdVectorFst convertToFst(const fst::SymbolTable &symbol, std::vector<int> map) const;
@@ -53,6 +53,7 @@ class NlpFstLoader : public FstLoader {
5353
vector<RawNlpRecord> mNlpRows;
5454
vector<std::string> mSpeakers;
5555
Json::Value mJsonNorm;
56+
Json::Value mWerSidecar;
5657
virtual const std::string &getToken(int index) const { return mToken.at(index); }
5758
};
5859

src/fstalign.cpp

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -636,34 +636,29 @@ void HandleWer(FstLoader *refLoader, FstLoader *hypLoader, SynonymEngine *engine
636636
CalculatePrecisionRecall(topAlignment, alignerOptions.pr_threshold);
637637

638638
RecordWer(topAlignment);
639-
if (!output_sbs.empty()) {
640-
logger->info("output_sbs = {}", output_sbs);
641-
WriteSbs(topAlignment, output_sbs);
639+
vector<shared_ptr<Stitching>> stitches;
640+
CtmFstLoader *ctm_hyp_loader = dynamic_cast<CtmFstLoader *>(hypLoader);
641+
NlpFstLoader *nlp_hyp_loader = dynamic_cast<NlpFstLoader *>(hypLoader);
642+
OneBestFstLoader *best_loader = dynamic_cast<OneBestFstLoader *>(hypLoader);
643+
if (ctm_hyp_loader) {
644+
stitches = make_stitches(topAlignment, ctm_hyp_loader->mCtmRows, {});
645+
} else if (nlp_hyp_loader) {
646+
stitches = make_stitches(topAlignment, {}, nlp_hyp_loader->mNlpRows);
647+
} else if (best_loader) {
648+
vector<string> tokens;
649+
tokens.reserve(best_loader->TokensSize());
650+
for (int i = 0; i < best_loader->TokensSize(); i++) {
651+
string token = best_loader->getToken(i);
652+
tokens.push_back(token);
653+
}
654+
stitches = make_stitches(topAlignment, {}, {}, tokens);
655+
} else {
656+
stitches = make_stitches(topAlignment);
642657
}
643658

644659
NlpFstLoader *nlp_ref_loader = dynamic_cast<NlpFstLoader *>(refLoader);
645660
if (nlp_ref_loader) {
646661
// We have an NLP reference, more metadata (e.g. speaker info) is available
647-
vector<shared_ptr<Stitching>> stitches;
648-
CtmFstLoader *ctm_hyp_loader = dynamic_cast<CtmFstLoader *>(hypLoader);
649-
NlpFstLoader *nlp_hyp_loader = dynamic_cast<NlpFstLoader *>(hypLoader);
650-
OneBestFstLoader *best_loader = dynamic_cast<OneBestFstLoader *>(hypLoader);
651-
if (ctm_hyp_loader) {
652-
stitches = make_stitches(topAlignment, ctm_hyp_loader->mCtmRows, {});
653-
} else if (nlp_hyp_loader) {
654-
stitches = make_stitches(topAlignment, {}, nlp_hyp_loader->mNlpRows);
655-
} else if (best_loader) {
656-
vector<string> tokens;
657-
tokens.reserve(best_loader->TokensSize());
658-
for (int i = 0; i < best_loader->TokensSize(); i++) {
659-
string token = best_loader->getToken(i);
660-
tokens.push_back(token);
661-
}
662-
stitches = make_stitches(topAlignment, {}, {}, tokens);
663-
} else {
664-
stitches = make_stitches(topAlignment);
665-
}
666-
667662
// Align stitches to the NLP, so stitches can access metadata
668663
try {
669664
align_stitches_to_nlp(nlp_ref_loader, &stitches);
@@ -693,6 +688,11 @@ void HandleWer(FstLoader *refLoader, FstLoader *hypLoader, SynonymEngine *engine
693688
}
694689
}
695690

691+
if (!output_sbs.empty()) {
692+
logger->info("output_sbs = {}", output_sbs);
693+
WriteSbs(topAlignment, stitches, output_sbs);
694+
}
695+
696696
if (!output_nlp.empty() && !nlp_ref_loader) {
697697
logger->warn("Attempted to output an Aligned NLP file without NLP reference, skipping output.");
698698
}

src/main.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ int main(int argc, char **argv) {
1616
setlocale(LC_ALL, "en_US.UTF-8");
1717
string ref_filename;
1818
string json_norm_filename;
19+
string wer_sidecar_filename;
1920
string hyp_filename;
2021
string log_filename = "";
2122
string output_nlp = "";
@@ -94,6 +95,8 @@ int main(int argc, char **argv) {
9495
c->add_option("--composition-approach", composition_approach,
9596
"Desired composition logic. Choices are 'standard' or 'adapted'");
9697
}
98+
get_wer->add_option("--wer-sidecar", wer_sidecar_filename,
99+
"WER sidecar json file.");
97100

98101
get_wer->add_option("--speaker-switch-context", speaker_switch_context_size,
99102
"Amount of context (in each direction) around "
@@ -166,6 +169,27 @@ int main(int argc, char **argv) {
166169
Json::parseFromStream(builder, ss, &obj, &errs);
167170
}
168171

172+
Json::Value wer_sidecar_obj;
173+
if (!wer_sidecar_filename.empty()) {
174+
console->info("reading wer sidecar info from {}", wer_sidecar_filename);
175+
ifstream ifs(wer_sidecar_filename);
176+
177+
Json::CharReaderBuilder builder;
178+
builder["collectComments"] = false;
179+
180+
JSONCPP_STRING errs;
181+
Json::parseFromStream(builder, ifs, &wer_sidecar_obj, &errs);
182+
183+
console->info("The json we just read [{}] has {} elements from its root", wer_sidecar_filename, wer_sidecar_obj.size());
184+
} else {
185+
stringstream ss;
186+
ss << "{}";
187+
188+
Json::CharReaderBuilder builder;
189+
JSONCPP_STRING errs;
190+
Json::parseFromStream(builder, ss, &wer_sidecar_obj, &errs);
191+
}
192+
169193
Json::Value hyp_json_obj;
170194
if (!hyp_json_norm_filename.empty()) {
171195
console->info("reading hypothesis json norm info from {}", hyp_json_norm_filename);
@@ -194,7 +218,7 @@ int main(int argc, char **argv) {
194218
NlpReader nlpReader = NlpReader();
195219
console->info("reading reference nlp from {}", ref_filename);
196220
auto vec = nlpReader.read_from_disk(ref_filename);
197-
NlpFstLoader *nlpFst = new NlpFstLoader(vec, obj, true);
221+
NlpFstLoader *nlpFst = new NlpFstLoader(vec, obj, wer_sidecar_obj, true);
198222
ref = nlpFst;
199223
} else if (EndsWithCaseInsensitive(ref_filename, string(".ctm"))) {
200224
console->info("reading reference ctm from {}", ref_filename);
@@ -212,11 +236,19 @@ int main(int argc, char **argv) {
212236
// loading "hypothesis" inputs
213237
if (EndsWithCaseInsensitive(hyp_filename, string(".nlp"))) {
214238
console->info("reading hypothesis nlp from {}", hyp_filename);
239+
// Make empty json for wer sidecar
240+
Json::Value hyp_empty_json;
241+
stringstream ss;
242+
ss << "{}";
243+
244+
Json::CharReaderBuilder builder;
245+
JSONCPP_STRING errs;
246+
Json::parseFromStream(builder, ss, &hyp_empty_json, &errs);
215247
NlpReader nlpReader = NlpReader();
216248
auto vec = nlpReader.read_from_disk(hyp_filename);
217249
// for now, nlp files passed as hypothesis won't have their labels handled as such
218250
// this also mean that json normalization will be ignored
219-
NlpFstLoader *nlpFst = new NlpFstLoader(vec, hyp_json_obj, false);
251+
NlpFstLoader *nlpFst = new NlpFstLoader(vec, hyp_json_obj, hyp_empty_json, false);
220252
hyp = nlpFst;
221253
} else if (EndsWithCaseInsensitive(hyp_filename, string(".ctm"))) {
222254
console->info("reading hypothesis ctm from {}", hyp_filename);

src/wer.cpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -327,16 +327,19 @@ void RecordTagWer(vector<shared_ptr<Stitching>> stitches) {
327327
for (auto &stitch : stitches) {
328328
if (!stitch->nlpRow.wer_tags.empty()) {
329329
for (auto wer_tag : stitch->nlpRow.wer_tags) {
330-
wer_results.insert(std::pair<std::string, WerResult>(wer_tag, {0, 0, 0, 0, 0}));
330+
int tag_start = wer_tag.find_first_not_of('#');
331+
int tag_end = wer_tag.find('_');
332+
string wer_tag_id = wer_tag.substr(tag_start, tag_end - tag_start);
333+
wer_results.insert(std::pair<std::string, WerResult>(wer_tag_id, {0, 0, 0, 0, 0}));
331334
// Check with rfind since other comments can be there
332335
bool del = stitch->comment.rfind("del", 0) == 0;
333336
bool ins = stitch->comment.rfind("ins", 0) == 0;
334337
bool sub = stitch->comment.rfind("sub", 0) == 0;
335-
wer_results[wer_tag].insertions += ins;
336-
wer_results[wer_tag].deletions += del;
337-
wer_results[wer_tag].substitutions += sub;
338+
wer_results[wer_tag_id].insertions += ins;
339+
wer_results[wer_tag_id].deletions += del;
340+
wer_results[wer_tag_id].substitutions += sub;
338341
if (!ins) {
339-
wer_results[wer_tag].numWordsInReference += 1;
342+
wer_results[wer_tag_id].numWordsInReference += 1;
340343
}
341344
}
342345
}
@@ -503,7 +506,7 @@ void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp)
503506
hyp = "";
504507
}
505508

506-
void WriteSbs(spWERA topAlignment, string sbs_filename) {
509+
void WriteSbs(spWERA topAlignment, vector<shared_ptr<Stitching>> stitches, string sbs_filename) {
507510
auto logger = logger::GetOrCreateLogger("wer");
508511
logger->set_level(spdlog::level::info);
509512

@@ -514,7 +517,7 @@ void WriteSbs(spWERA topAlignment, string sbs_filename) {
514517
triple *tk_pair = new triple();
515518
string prev_tk_classLabel = "";
516519
logger->info("Side-by-Side alignment info going into {}", sbs_filename);
517-
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}", "ref_token", "hyp_token", "IsErr", "Class") << endl;
520+
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", "ref_token", "hyp_token", "IsErr", "Class", "Wer_Tag_Entities") << endl;
518521

519522
// keep track of error groupings
520523
ErrorGroups groups_err;
@@ -525,10 +528,15 @@ void WriteSbs(spWERA topAlignment, string sbs_filename) {
525528
std::set<std::string> op_set = {"<ins>", "<del>", "<sub>"};
526529

527530
size_t offset = 2; // line number in output file where first triple starts
528-
while (visitor.NextTriple(tk_pair)) {
529-
string tk_classLabel = tk_pair->classLabel;
530-
string ref_tk = tk_pair->ref;
531-
string hyp_tk = tk_pair->hyp;
531+
for (auto p_stitch: stitches) {
532+
string tk_classLabel = p_stitch->classLabel;
533+
string tk_wer_tags = "";
534+
auto wer_tags = p_stitch->nlpRow.wer_tags;
535+
for (auto wer_tag: wer_tags) {
536+
tk_wer_tags = tk_wer_tags + wer_tag + "|";
537+
}
538+
string ref_tk = p_stitch->reftk;
539+
string hyp_tk = p_stitch->hyptk;
532540
string tag = "";
533541

534542
if (ref_tk == NOOP) {
@@ -560,7 +568,7 @@ void WriteSbs(spWERA topAlignment, string sbs_filename) {
560568
eff_class = tk_classLabel;
561569
}
562570

563-
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}", ref_tk, hyp_tk, tag, eff_class) << endl;
571+
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", ref_tk, hyp_tk, tag, eff_class, tk_wer_tags) << endl;
564572
offset++;
565573
}
566574

src/wer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,4 @@ void CalculatePrecisionRecall(spWERA &topAlignment, int threshold);
4848
typedef vector<pair<size_t, string>> ErrorGroups;
4949

5050
void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp);
51-
void WriteSbs(spWERA topAlignment, string sbs_filename);
51+
void WriteSbs(spWERA topAlignment, vector<shared_ptr<Stitching>> stitches, string sbs_filename);

test/data/syn_1.hyp.sbs

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
1-
ref_token hyp_token IsErr Class
2-
we <del> ERR
3-
will we'll ERR
4-
have have
5-
a a
6-
nice nice
7-
evening evening
8-
<ins> um ERR
9-
no no
10-
matter matter
11-
what what
12-
will will
13-
happen happen
14-
<ins> it ERR
15-
um is ERR
16-
it's uh ERR
17-
a a
18-
good good
19-
opportunity opportunity
20-
to to
21-
do <del> ERR
22-
this this
23-
you'll you'll
24-
<ins> uh ERR
25-
see see
1+
ref_token hyp_token IsErr Class Wer_Tag_Entities
2+
we <del> ERR
3+
will we'll ERR
4+
have have
5+
a a
6+
nice nice
7+
evening evening
8+
<ins> um ERR
9+
no no
10+
matter matter
11+
what what
12+
will will
13+
happen happen
14+
<ins> it ERR
15+
um is ERR
16+
it's uh ERR
17+
a a
18+
good good
19+
opportunity opportunity
20+
to to
21+
do <del> ERR
22+
this this
23+
you'll you'll
24+
<ins> uh ERR
25+
see see
2626
------------------------------------------------------------
2727
Line Group
2828
2 we will <-> we'll

test/data/twenty.hyp-a2.sbs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
ref_token hyp_token IsErr Class
2-
20 <del> ERR ___1_CARDINAL___
3-
in in
4-
twenty twenty ___2_YEAR___
5-
twenty thirty ERR ___2_YEAR___
6-
is is
7-
one one ___3_CARDINAL___
8-
twenty twenty ___3_CARDINAL___
9-
<ins> two ERR ___3_CARDINAL___
10-
three three ___3_CARDINAL___
1+
ref_token hyp_token IsErr Class Wer_Tag_Entities
2+
20 <del> ERR ___1_CARDINAL___
3+
in in
4+
twenty twenty ___2_YEAR___
5+
twenty thirty ERR ___2_YEAR___
6+
is is
7+
one one ___3_CARDINAL___
8+
twenty twenty ___3_CARDINAL___
9+
<ins> two ERR ___3_CARDINAL___
10+
three three ___3_CARDINAL___
1111
------------------------------------------------------------
1212
Line Group
1313
2 20 <-> ***

0 commit comments

Comments
 (0)