Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow a regular expression to describe tokens to suppress #1997

Merged
merged 3 commits into from Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -148,6 +148,9 @@ public void tdrzEnable(boolean enable) {
tdrz_enable = enable ? CBool.TRUE : CBool.FALSE;
}

/** Regular expression matching tokens to suppress. */
public String suppress_regex;

/** Tokens to provide to the whisper decoder as an initial prompt.
* These are prepended to any existing text context from a previous call. */
public String initial_prompt;
Expand Down Expand Up @@ -319,7 +322,7 @@ protected List<String> getFieldOrder() {
"no_context", "single_segment", "no_timestamps",
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
"tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
"new_segment_callback", "new_segment_callback_user_data",
Expand Down
7 changes: 7 additions & 0 deletions examples/command/command.cpp
Expand Up @@ -52,6 +52,9 @@ struct whisper_params {
std::string prompt;
std::string context;
std::string grammar;

// A regular expression that matches tokens to suppress
std::string suppress_regex;
};

void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
Expand Down Expand Up @@ -85,6 +88,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
Expand Down Expand Up @@ -122,6 +126,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
fprintf(stderr, "\n");
}

Expand Down Expand Up @@ -167,6 +172,8 @@ std::string transcribe(

wparams.initial_prompt = params.context.data();

wparams.suppress_regex = params.suppress_regex.c_str();

const auto & grammar_parsed = params.grammar_parsed;
auto grammar_rules = grammar_parsed.c_rules();

Expand Down
8 changes: 8 additions & 0 deletions examples/main/main.cpp
Expand Up @@ -6,6 +6,7 @@
#include <cmath>
#include <fstream>
#include <cstdio>
#include <regex>
#include <string>
#include <thread>
#include <vector>
Expand Down Expand Up @@ -78,6 +79,9 @@ struct whisper_params {
// [TDRZ] speaker turn string
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line

// A regular expression that matches tokens to suppress
std::string suppress_regex;

std::string openvino_encode_device = "CPU";

std::string dtw = "";
Expand Down Expand Up @@ -160,6 +164,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
Expand Down Expand Up @@ -223,6 +228,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
Expand Down Expand Up @@ -1033,6 +1039,8 @@ int main(int argc, char ** argv) {

wparams.tdrz_enable = params.tinydiarize; // [TDRZ]

wparams.suppress_regex = params.suppress_regex.c_str();

wparams.initial_prompt = params.prompt.c_str();

wparams.greedy.best_of = params.best_of;
Expand Down
13 changes: 13 additions & 0 deletions whisper.cpp
Expand Up @@ -4553,6 +4553,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

/*.tdrz_enable =*/ false,

/*.suppress_regex =*/ nullptr,

/*.initial_prompt =*/ nullptr,
/*.prompt_tokens =*/ nullptr,
/*.prompt_n_tokens =*/ 0,
Expand Down Expand Up @@ -4796,6 +4798,17 @@ static void whisper_process_logits(
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
}

// suppress any tokens matching a regular expression
// ref: https://github.com/openai/whisper/discussions/1041
//
// The pattern has to match the whole token text (std::regex_match, not a search).
// A null or empty pattern disables the filter: callers that forward
// std::string::c_str() never pass nullptr, so without the empty check every call
// would build a regex and test it against the entire vocabulary even when no
// pattern was requested.
if (params.suppress_regex != nullptr && params.suppress_regex[0] != '\0') {
    const std::regex re(params.suppress_regex);
    // bind by const reference so each vocabulary entry (token text + id) is not
    // copied on every iteration
    for (const auto & token_id : vocab.token_to_id) {
        if (std::regex_match(token_id.first, re)) {
            logits[token_id.second] = -INFINITY;
        }
    }
}

// suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
if (params.suppress_non_speech_tokens) {
Expand Down
3 changes: 3 additions & 0 deletions whisper.h
Expand Up @@ -505,6 +505,9 @@ extern "C" {
// [EXPERIMENTAL] [TDRZ] tinydiarize
bool tdrz_enable; // enable tinydiarize speaker turn detection

// A regular expression that matches tokens to suppress
const char * suppress_regex;

// tokens to provide to the whisper decoder as initial prompt
// these are prepended to any existing text context from a previous call
// use whisper_tokenize() to convert text to tokens
Expand Down