Skip to content

Commit a1bcf8d

Browse files
authored
Merge pull request #2183 from coqui-ai/decoder-wav2vec2-batch
Expose batch version of decoder API for wav2vec2 AM
2 parents bfbf259 + 480c767 commit a1bcf8d

File tree

5 files changed

+82
-15
lines changed

5 files changed

+82
-15
lines changed

.github/workflows/build-and-test.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1567,6 +1567,9 @@ jobs:
15671567
run: |
15681568
sudo rm -r /Library/Developer/CommandLineTools
15691569
- run: ./ci_scripts/host-build.sh ${{ matrix.arch }}
1570+
- name: Setup tmate session
1571+
uses: mxschmitt/action-tmate@v3
1572+
if: failure()
15701573
- run: ./ci_scripts/package.sh
15711574
- uses: actions/upload-artifact@v2
15721575
with:

native_client/ctcdecode/__init__.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,78 @@ def ctc_beam_search_decoder_batch(
311311
return batch_beam_results
312312

313313

314+
def ctc_beam_search_decoder_for_wav2vec2am_batch(
315+
probs_seq,
316+
seq_lengths,
317+
alphabet,
318+
beam_size,
319+
num_threads,
320+
cutoff_prob=1.0,
321+
cutoff_top_n=40,
322+
blank_id=-1,
323+
ignored_symbols=frozenset(),
324+
scorer=None,
325+
hot_words=dict(),
326+
num_results=1,
327+
):
328+
"""Wrapper for the batched CTC beam search decoder for wav2vec2 AM.
329+
330+
:param probs_seq: 3-D list with each element as an instance of 2-D list
331+
of probabilities used by ctc_beam_search_decoder().
332+
:type probs_seq: 3-D list
333+
:param alphabet: alphabet list.
334+
:alphabet: Alphabet
335+
:param beam_size: Width for beam search.
336+
:type beam_size: int
337+
:param num_threads: Number of threads to use for processing batch.
338+
:type num_threads: int
339+
:param cutoff_prob: Cutoff probability in alphabet pruning,
340+
default 1.0, no pruning.
341+
:type cutoff_prob: float
342+
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
343+
characters with highest probs in alphabet will be
344+
used in beam search, default 40.
345+
:type cutoff_top_n: int
346+
:param scorer: External scorer for partially decoded sentence, e.g. word
347+
count or language model.
348+
:type scorer: Scorer
349+
:param hot_words: Map of words (keys) to their assigned boosts (values)
350+
:type hot_words: dict[string, float]
351+
:param num_results: Number of beams to return.
352+
:type num_results: int
353+
:return: List of tuples of confidence and sentence as decoding
354+
results, in descending order of the confidence.
355+
:rtype: list
356+
"""
357+
batch_beam_results = swigwrapper.ctc_beam_search_decoder_for_wav2vec2am_batch(
358+
probs_seq,
359+
seq_lengths,
360+
alphabet,
361+
beam_size,
362+
num_threads,
363+
cutoff_prob,
364+
cutoff_top_n,
365+
blank_id,
366+
ignored_symbols,
367+
scorer,
368+
hot_words,
369+
num_results,
370+
)
371+
batch_beam_results = [
372+
[
373+
DecodeResult(
374+
res.confidence,
375+
alphabet.Decode(res.tokens),
376+
[int(t) for t in res.tokens],
377+
[int(t) for t in res.timesteps],
378+
)
379+
for res in beam_results
380+
]
381+
for beam_results in batch_beam_results
382+
]
383+
return batch_beam_results
384+
385+
314386
class FlashlightDecoderState(swigwrapper.FlashlightDecoderState):
315387
"""
316388
This class contains constants used to specify the desired behavior for the

native_client/ctcdecode/ctc_beam_search_decoder.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ ctc_beam_search_decoder_batch(
652652
}
653653

654654
std::vector<std::vector<Output>>
655-
ctc_beam_search_decoder_batch_for_wav2vec2am(
655+
ctc_beam_search_decoder_for_wav2vec2am_batch(
656656
const double *probs,
657657
int batch_size,
658658
int time_dim,
@@ -661,7 +661,7 @@ ctc_beam_search_decoder_batch_for_wav2vec2am(
661661
int seq_lengths_size,
662662
const Alphabet &alphabet,
663663
size_t beam_size,
664-
size_t num_processes,
664+
size_t num_threads,
665665
double cutoff_prob,
666666
size_t cutoff_top_n,
667667
int blank_id,
@@ -670,10 +670,10 @@ ctc_beam_search_decoder_batch_for_wav2vec2am(
670670
std::unordered_map<std::string, float> hot_words,
671671
size_t num_results)
672672
{
673-
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
673+
VALID_CHECK_GT(num_threads, 0, "num_threads must be nonnegative!");
674674
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
675675
// thread pool
676-
ThreadPool pool(num_processes);
676+
ThreadPool pool(num_threads);
677677

678678
// enqueue the tasks of decoding
679679
std::vector<std::future<std::vector<Output>>> res;

native_client/ctcdecode/ctc_beam_search_decoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ ctc_beam_search_decoder_batch(
333333
* result for one audio sample.
334334
*/
335335
std::vector<std::vector<Output>>
336-
ctc_beam_search_decoder_batch_for_wav2vec2am(
336+
ctc_beam_search_decoder_for_wav2vec2am_batch(
337337
const double* probs,
338338
int batch_size,
339339
int time_dim,

native_client/definitions.mk

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,7 @@ CFLAGS := -mmacosx-version-min=10.10 -target x86_64-apple-macos
3838
LDFLAGS := -mmacosx-version-min=10.10 -target x86_64-apple-macos10.10
3939

4040
SOX_CFLAGS := $(shell pkg-config --cflags sox)
41-
LIBSOX_PATH := $(shell echo `pkg-config --libs-only-L sox | sed -e 's/^-L//'`/lib`pkg-config --libs-only-l sox | sed -e 's/^-l//'`.dylib)
42-
LIBOPUSFILE_PATH := $(shell echo `pkg-config --libs-only-L opusfile | sed -e 's/^-L//'`/lib`pkg-config --libs-only-l opusfile | sed -e 's/^-l//'`.dylib)
43-
LIBSOX_STATIC_DEPS := $(shell echo `otool -L $(LIBSOX_PATH) | tail -n +2 | cut -d' ' -f1 | grep /opt/ | sed -E "s/\.[[:digit:]]+\.dylib/\.a/" | tr '\n' ' '`)
44-
LIBOPUSFILE_STATIC_DEPS := $(shell echo `otool -L $(LIBOPUSFILE_PATH) | tail -n +2 | cut -d' ' -f1 | grep /opt/ | sed -E "s/\.[[:digit:]]+\.dylib/\.a/" | tr '\n' ' '`)
45-
SOX_LDFLAGS := $(LIBSOX_STATIC_DEPS) $(LIBOPUSFILE_STATIC_DEPS) -framework CoreAudio -lz
41+
SOX_LDFLAGS := $(shell pkg-config --libs sox) -framework CoreAudio -lz
4642
else
4743
SOX_LDFLAGS := `pkg-config --libs sox`
4844
endif # OS others
@@ -128,11 +124,7 @@ CFLAGS := -mmacosx-version-min=11.0 -target arm64-apple-macos11
128124
LDFLAGS := -mmacosx-version-min=11.0 -target arm64-apple-macos11
129125

130126
SOX_CFLAGS := $(shell arm-pkg-config --cflags sox)
131-
LIBSOX_PATH := $(shell echo `arm-pkg-config --libs-only-L sox | sed -e 's/^-L//'`/lib`arm-pkg-config --libs-only-l sox | sed -e 's/^-l//'`.dylib)
132-
LIBOPUSFILE_PATH := $(shell echo `arm-pkg-config --libs-only-L opusfile | sed -e 's/^-L//'`/lib`arm-pkg-config --libs-only-l opusfile | sed -e 's/^-l//'`.dylib)
133-
LIBSOX_STATIC_DEPS := $(shell echo `otool -L $(LIBSOX_PATH) | tail -n +2 | cut -d' ' -f1 | grep /opt/ | sed -E "s/\.[[:digit:]]+\.dylib/\.a/" | tr '\n' ' '`)
134-
LIBOPUSFILE_STATIC_DEPS := $(shell echo `otool -L $(LIBOPUSFILE_PATH) | tail -n +2 | cut -d' ' -f1 | grep /opt/ | sed -E "s/\.[[:digit:]]+\.dylib/\.a/" | tr '\n' ' '`)
135-
SOX_LDFLAGS := $(LIBSOX_STATIC_DEPS) $(LIBOPUSFILE_STATIC_DEPS) -framework CoreAudio -lz
127+
SOX_LDFLAGS := $(shell arm-pkg-config --libs sox) -framework CoreAudio -lz
136128
endif
137129

138130
# -Wl,--no-as-needed is required to force linker not to evict libs it thinks we

0 commit comments

Comments
 (0)