Skip to content

Commit 4c579ad

Browse files
Nerd-1422: Add flag for reading punctuation from nlp as separate tokens (#10)
* Nerd-1422: Add flag for reading punctuation from nlp as separate tokens * test * version file --------- Co-authored-by: Nishchal Bhandari <[email protected]>
1 parent 99afe1b commit 4c579ad

File tree

9 files changed

+139
-8
lines changed

9 files changed

+139
-8
lines changed

src/Nlp.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,21 @@ NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value norma
2020
: NlpFstLoader(records, normalization, wer_sidecar, true) {}
2121

2222
NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization,
23-
Json::Value wer_sidecar, bool processLabels)
23+
Json::Value wer_sidecar, bool processLabels, bool use_punctuation)
2424
: FstLoader() {
25-
mNlpRows = records;
2625
mJsonNorm = normalization;
2726
mWerSidecar = wer_sidecar;
2827
std::string last_label;
2928
bool firstTk = true;
3029

3130

3231
// fuse multiple rows that have the same id/label into one entry only
33-
for (auto &row : mNlpRows) {
32+
for (auto &row : records) {
33+
mNlpRows.push_back(row);
3434
auto curr_tk = row.token;
3535
auto curr_label = row.best_label;
3636
auto curr_label_id = row.best_label_id;
37+
auto punctuation = row.punctuation;
3738
auto curr_row_tags = row.wer_tags;
3839
// Update wer tags in records to real string labels
3940
vector<string> real_wer_tags;
@@ -83,6 +84,14 @@ NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value norma
8384
std::string lower_cased = UnicodeLowercase(curr_tk);
8485
mToken.push_back(lower_cased);
8586
mSpeakers.push_back(speaker);
87+
if (use_punctuation && punctuation != "") {
88+
mToken.push_back(punctuation);
89+
mSpeakers.push_back(speaker);
90+
RawNlpRecord nlp_row = row;
91+
nlp_row.token = nlp_row.punctuation;
92+
nlp_row.punctuation = "";
93+
mNlpRows.push_back(nlp_row);
94+
}
8695
}
8796

8897
firstTk = false;

src/Nlp.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class NlpReader {
4343

4444
class NlpFstLoader : public FstLoader {
4545
public:
46-
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar, bool processLabels);
46+
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar, bool processLabels, bool use_punctuation = false);
4747
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar);
4848
virtual ~NlpFstLoader();
4949
virtual void addToSymbolTable(fst::SymbolTable &symbol) const;

src/fstalign.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,10 @@ void align_stitches_to_nlp(NlpFstLoader *refLoader, vector<shared_ptr<Stitching>
367367
continue;
368368
}
369369

370+
if (nlpRowIndex >= nlpMaxRow) {
371+
logger->warn("Ran out of nlp rows. {} rows, {} stitches", nlpMaxRow, numStitches);
372+
break;
373+
}
370374
auto nlpPart = nlpRows[nlpRowIndex];
371375
string nlp_classLabel = GetClassLabel(nlpPart.best_label);
372376

src/main.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ int main(int argc, char **argv) {
3232
int numBests = 100;
3333
int levenstein_maximum_error_streak = 100;
3434
bool record_case_stats = false;
35+
bool use_punctuation = false;
3536
bool disable_approximate_alignment = false;
3637

3738
bool disable_cutoffs = false;
@@ -120,6 +121,7 @@ int main(int argc, char **argv) {
120121
get_wer->add_flag("--record-case-stats", record_case_stats,
121122
"Record precision/recall for how well the hypothesis"
122123
"casing matches the reference.");
124+
get_wer->add_flag("--use-punctuation", use_punctuation, "Treat punctuation from nlp rows as separate tokens");
123125

124126
// CLI11_PARSE(app, argc, argv);
125127
try {
@@ -218,7 +220,7 @@ int main(int argc, char **argv) {
218220
NlpReader nlpReader = NlpReader();
219221
console->info("reading reference nlp from {}", ref_filename);
220222
auto vec = nlpReader.read_from_disk(ref_filename);
221-
NlpFstLoader *nlpFst = new NlpFstLoader(vec, obj, wer_sidecar_obj, true);
223+
NlpFstLoader *nlpFst = new NlpFstLoader(vec, obj, wer_sidecar_obj, true, use_punctuation);
222224
ref = nlpFst;
223225
} else if (EndsWithCaseInsensitive(ref_filename, string(".ctm"))) {
224226
console->info("reading reference ctm from {}", ref_filename);
@@ -248,7 +250,7 @@ int main(int argc, char **argv) {
248250
auto vec = nlpReader.read_from_disk(hyp_filename);
249251
// for now, nlp files passed as hypothesis won't have their labels handled as such
250252
// this also mean that json normalization will be ignored
251-
NlpFstLoader *nlpFst = new NlpFstLoader(vec, hyp_json_obj, hyp_empty_json, false);
253+
NlpFstLoader *nlpFst = new NlpFstLoader(vec, hyp_json_obj, hyp_empty_json, false, use_punctuation);
252254
hyp = nlpFst;
253255
} else if (EndsWithCaseInsensitive(hyp_filename, string(".ctm"))) {
254256
console->info("reading hypothesis ctm from {}", hyp_filename);

src/version.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#pragma once
22

33
#define FSTALIGNER_VERSION_MAJOR 1
4-
#define FSTALIGNER_VERSION_MINOR 6
5-
#define FSTALIGNER_VERSION_PATCH 1
4+
#define FSTALIGNER_VERSION_MINOR 9
5+
#define FSTALIGNER_VERSION_PATCH 0

test/data/short.aligned.punc.nlp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
token|speaker|ts|endTs|punctuation|prepunctuation|case|tags|wer_tags|oldTs|oldEndTs|ali_comment
2+
<crosstalk>|2|0.0000|0.0000|||LC|[]|[]|||
3+
Yeah|1|0.0000|0.0000|,||UC|[]|[]|||
4+
,|1|0.0000|0.0000|||UC|[]|[]|||
5+
yeah|1|||,||LC|[]|[]|||del
6+
,|1|||||LC|[]|[]|||del
7+
right|1|0.0000|0.0000|.||LC|[]|[]|||
8+
.|1|||||LC|[]|[]|||del
9+
Yeah|1|||,||UC|[]|[]|||del
10+
,|1|||||UC|[]|[]|||del
11+
all|1|||||LC|[]|[]|||del
12+
right|1|||,||LC|[]|[]|||del
13+
,|1|0.0000|0.0000|||LC|[]|[]|||sub(i'll)
14+
probably|1|0.0000|0.0000|||LC|[]|[]|||sub(do)
15+
just|1|0.0000|0.0000|||LC|[]|[]|||
16+
that|1|0.0000|0.0000|.||LC|[]|[]|||
17+
.|1|0.0000|0.0000|||LC|[]|[]|||sub(?)
18+
Are|3|0.0000|0.0000|||UC|[]|[]|||
19+
there|3|0.0000|0.0000|||LC|[]|[]|||
20+
any|3|0.0000|0.0000|||LC|[]|[]|||
21+
visuals|3|0.0000|0.0000|||LC|[]|[]|||
22+
that|3|0.0000|0.0000|||LC|[]|[]|||
23+
come|3|0.0000|0.0000|||LC|[]|[]|||
24+
to|3|0.0000|0.0000|||LC|[]|[]|||
25+
mind|3|0.0000|0.0000|||LC|[]|[]|||
26+
or|3|0.0000|0.0000|||LC|[]|[]|||
27+
Yeah|1|0.0000|0.0000|,||UC|[]|[]|||
28+
,|1|0.0000|0.0000|||UC|[]|[]|||
29+
sure|1|0.0000|0.0000|.||LC|[]|[]|||
30+
.|1|0.0000|0.0000|||LC|[]|[]|||
31+
When|1|0.0000|0.0000|||UC|[]|[]|||
32+
I|1|0.0000|0.0000|||CA|[]|[]|||
33+
hear|1|0.0000|0.0000|||LC|[]|[]|||
34+
Foobar|1|0.0000|0.0000|,||UC|[]|[]|||
35+
,|1|0.0000|0.0000|||UC|[]|[]|||
36+
I|1|0.0000|0.0000|||CA|[]|[]|||
37+
think|1|0.0000|0.0000|||LC|[]|[]|||
38+
about|1|0.0000|0.0000|||LC|[]|[]|||
39+
just|1|0.0000|0.0000|||LC|[]|[]|||
40+
that|1|0.0000|0.0000|:||LC|[]|[]|||
41+
:|1|0.0000|0.0000|||LC|[]|[]|||
42+
foo|1|0.0000|0.0000|||LC|[]|[]|||sub(,)
43+
a|1|0.0000|0.0000|||LC|[]|[]|||

test/data/short_punc.hyp.nlp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
token|speaker|ts|endTs|punctuation|case|tags
2+
<crosstalk>|2||||LC|[]
3+
Yeah|1|||,|UC|[]
4+
right|1||||LC|[]
5+
I'll|1||||UC|[]
6+
do|1||||LC|[]
7+
just|1||||LC|[]
8+
that|1|||?|LC|[]
9+
Are|3||||UC|[]
10+
there|3||||LC|[]
11+
any|3||||LC|[]
12+
visuals|3||||LC|[]
13+
that|3||||LC|[]
14+
come|3||||LC|[]
15+
to|3||||LC|[]
16+
mind|3||||LC|[]
17+
or|3|||?|LC|[]
18+
Yeah|1|||,|UC|[]
19+
sure|1|||.|LC|[]
20+
When|1||||UC|[]
21+
I|1||||CA|[]
22+
hear|1||||LC|[]
23+
Foobar|1|||,|UC|[]
24+
I|1||||CA|[]
25+
think|1||||LC|[]
26+
about|1||||LC|[]
27+
just|1||||LC|[]
28+
that|1|||:|LC|[]
29+
Foobar|1|||,|UC|[]
30+
a|1||||LC|[]

test/data/short_punc.ref.nlp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
token|speaker|ts|endTs|punctuation|case|tags
2+
<crosstalk>|2||||LC|[]
3+
Yeah|1|||,|UC|[]
4+
yeah|1|||,|LC|[]
5+
right|1|||.|LC|[]
6+
Yeah|1|||,|UC|[]
7+
all|1||||LC|[]
8+
right|1|||,|LC|[]
9+
probably|1||||LC|[]
10+
just|1||||LC|[]
11+
that|1|||.|LC|[]
12+
Are|3||||UC|[]
13+
there|3||||LC|[]
14+
any|3||||LC|[]
15+
visuals|3||||LC|[]
16+
that|3||||LC|[]
17+
come|3||||LC|[]
18+
to|3||||LC|[]
19+
mind|3||||LC|[]
20+
or-|3||||LC|[]
21+
Yeah|1|||,|UC|[]
22+
sure|1|||.|LC|[]
23+
When|1||||UC|[]
24+
I|1||||CA|[]
25+
hear|1||||LC|[]
26+
Foobar|1|||,|UC|[]
27+
I|1||||CA|[]
28+
think|1||||LC|[]
29+
about|1||||LC|[]
30+
just|1||||LC|[]
31+
that|1|||:|LC|[]
32+
foo|1||||LC|[]
33+
a|1||||LC|[]

test/fstalign_Test.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,16 @@ TEST_CASE_METHOD(UniqueTestsFixture, "main-adapted-composition()") {
659659
REQUIRE_THAT(result, Contains("WER: INS:0 DEL:3 SUB:3"));
660660
}
661661

662+
SECTION("wer with punctuation(nlp output)") {
663+
const auto result =
664+
exec(command("wer", approach, "short_punc.ref.nlp", "short_punc.hyp.nlp", sbs_output, nlp_output, TEST_SYNONYMS)+" --use-punctuation");
665+
const auto testFile = std::string{TEST_DATA} + "short.aligned.punc.nlp";
666+
667+
REQUIRE(compareFiles(nlp_output.c_str(), testFile.c_str()));
668+
REQUIRE_THAT(result, Contains("WER: 13/42 = 0.3095"));
669+
REQUIRE_THAT(result, Contains("WER: INS:2 DEL:7 SUB:4"));
670+
}
671+
662672
// alignment tests
663673

664674
SECTION("align_1") {

0 commit comments

Comments
 (0)