add word error count (#112)
nikvaessen authored Feb 3, 2025
1 parent 2a036a1 commit be6b690
Showing 5 changed files with 296 additions and 1,713 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -20,9 +20,11 @@ poetry.lock
# virtual environments
venv
.venv
uv.lock

# cache folders
.pytest_cache
.benchmarks
/docs/site/
/site/

139 changes: 129 additions & 10 deletions src/jiwer/alignment.py
@@ -17,15 +17,15 @@
#

"""
Utility method to visualize the alignment between one or more reference and hypothesis
pairs.
Utility method to visualize the alignment and errors between one or more reference
and hypothesis pairs.
"""

from collections import defaultdict
from typing import List, Union, Optional

from jiwer.process import CharacterOutput, WordOutput, AlignmentChunk

__all__ = ["visualize_alignment"]
__all__ = ["visualize_alignment", "collect_error_counts", "visualize_error_counts"]


def visualize_alignment(
@@ -65,16 +65,19 @@ def visualize_alignment(
```
will produce this visualization:
```txt
sentence 1
=== SENTENCE 1 ===
REF: # short one here
HYP: shoe order one *
I S D
sentence 2
=== SENTENCE 2 ===
REF: quite a bit of # # longer sentence #
HYP: quite * bit of an even longest sentence here
D I I S I
=== SUMMARY ===
number of sentences: 2
substitutions=2 deletions=2 insertions=4 hits=5
@@ -87,12 +90,13 @@ def visualize_alignment(
When `show_measures=False`, only the alignment will be printed:
```txt
sentence 1
=== SENTENCE 1 ===
REF: # short one here
HYP: shoe order one *
I S D
sentence 2
=== SENTENCE 2 ===
REF: quite a bit of # # longer sentence #
HYP: quite * bit of an even longest sentence here
D I I S I
@@ -101,7 +105,7 @@ def visualize_alignment(
When setting `line_width=80`, the following output will be split into multiple lines:
```txt
sentence 1
=== SENTENCE 1 ===
REF: This is a very long sentence that is *** much longer than the previous one
HYP: This is a very loong sentence that is not much longer than the previous one
S I
@@ -122,13 +126,14 @@
        ):
            continue

        final_str += f"sentence {idx+1}\n"
        final_str += f"=== SENTENCE {idx+1} ===\n\n"
        final_str += _construct_comparison_string(
            gt, hp, chunks, include_space_seperator=not is_cer, line_width=line_width
        )
        final_str += "\n"

    if show_measures:
        final_str += "=== SUMMARY ===\n"
        final_str += f"number of sentences: {len(alignment)}\n"
        final_str += f"substitutions={output.substitutions} "
        final_str += f"deletions={output.deletions} "
@@ -213,3 +218,117 @@ def _construct_comparison_string(
        return agg_str + f"{ref_str[:-1]}\n{hyp_str[:-1]}\n{op_str[:-1]}\n"
    else:
        return agg_str + f"{ref_str}\n{hyp_str}\n{op_str}\n"


def collect_error_counts(output: Union[WordOutput, CharacterOutput]):
    """
    Retrieve three dictionaries, which count the frequency of how often
    each word or character was substituted, inserted, or deleted.
    The substitution dictionary has, as keys, a 2-tuple (from, to).
    The other two dictionaries have the inserted/deleted words or characters as keys.
    Args:
        output: The processed output of reference and hypothesis pair(s).
    Returns:
        A three-tuple of dictionaries, in the order substitutions, insertions, deletions.
    """
    substitutions = defaultdict(lambda: 0)
    insertions = defaultdict(lambda: 0)
    deletions = defaultdict(lambda: 0)

    for idx, sentence_chunks in enumerate(output.alignments):
        ref = output.references[idx]
        hyp = output.hypotheses[idx]
        sep = " " if isinstance(output, WordOutput) else ""

        for chunk in sentence_chunks:
            if chunk.type == "insert":
                inserted = sep.join(hyp[chunk.hyp_start_idx : chunk.hyp_end_idx])
                insertions[inserted] += 1
            if chunk.type == "delete":
                deleted = sep.join(ref[chunk.ref_start_idx : chunk.ref_end_idx])
                deletions[deleted] += 1
            if chunk.type == "substitute":
                replaced = sep.join(ref[chunk.ref_start_idx : chunk.ref_end_idx])
                by = sep.join(hyp[chunk.hyp_start_idx : chunk.hyp_end_idx])
                substitutions[(replaced, by)] += 1

    return substitutions, insertions, deletions
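
For reference, a minimal usage sketch of the new `collect_error_counts` helper (the reference/hypothesis sentences are made up, and the package-level import assumes the function is re-exported from `jiwer` alongside `visualize_alignment`; otherwise import it from `jiwer.alignment`):

```python
import jiwer

# Hypothetical sentence pairs, loosely mirroring the docstring example above.
out = jiwer.process_words(
    reference=["short one here", "quite a bit of longer sentence"],
    hypothesis=["shoe order one", "quite bit of an even longest sentence here"],
)

# Three dicts: {(replaced, by): count}, {inserted: count}, {deleted: count}.
subs, ins, dels = jiwer.collect_error_counts(out)
print(dict(subs))  # e.g. {('short', 'shoe order'): 1, ...}, depending on the alignment
```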


def visualize_error_counts(
    output: Union[WordOutput, CharacterOutput],
    show_substitutions: bool = True,
    show_insertions: bool = True,
    show_deletions: bool = True,
    top_k: Optional[int] = None,
):
    """
    Visualize which words (or characters), and how often, were substituted, inserted, or deleted.
    Args:
        output: The processed output of reference and hypothesis pair(s).
        show_substitutions: If true, visualize substitution errors.
        show_insertions: If true, visualize insertion errors.
        show_deletions: If true, visualize deletion errors.
        top_k: If set, only visualize the k most frequent errors.
    Returns: A string which visualizes the words/characters and their frequencies.
    """
    s, i, d = collect_error_counts(output)

    def build_list(errors: dict):
        if len(errors) == 0:
            return "none"

        keys = [k for k in errors.keys()]
        keys = sorted(keys, reverse=True, key=lambda k: errors[k])

        if top_k is not None:
            keys = keys[:top_k]

        # we get the maximum length of all words to nicely pad output
        ln = max(len(k) if isinstance(k, str) else max(len(e) for e in k) for k in keys)

        # here we construct the string
        build = ""

        for count, (k, v) in enumerate(
            sorted(errors.items(), key=lambda tpl: tpl[1], reverse=True)
        ):
            if top_k is not None and count >= top_k:
                break

            if isinstance(k, tuple):
                build += f"{k[0]: <{ln}} --> {k[1]:<{ln}} = {v}x\n"
            else:
                build += f"{k:<{ln}} = {v}x\n"

        return build

    output = ""

    if show_substitutions:
        if output != "":
            output += "\n"
        output += "=== SUBSTITUTIONS ===\n"
        output += build_list(s)

    if show_insertions:
        if output != "":
            output += "\n"
        output += "=== INSERTIONS ===\n"
        output += build_list(i)

    if show_deletions:
        if output != "":
            output += "\n"
        output += "=== DELETIONS ===\n"
        output += build_list(d)

    if output[-1:] == "\n":
        output = output[:-1]

    return output
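
A similar sketch for `visualize_error_counts` (same caveats as above; the exact report depends on the alignment of the hypothetical sentences):

```python
import jiwer

# Hypothetical reference/hypothesis pair with one substitution and one insertion.
out = jiwer.process_words(
    reference=["the quick brown fox"],
    hypothesis=["the quik brown fox fox"],
)

# Full report, then a report without substitutions and limited to the
# five most frequent errors per category.
print(jiwer.visualize_error_counts(out))
print(jiwer.visualize_error_counts(out, show_substitutions=False, top_k=5))
```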