diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java index 43c9a0dcf34..60d8334b935 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java @@ -148,6 +148,9 @@ public void tdrzEnable(boolean enable) { tdrz_enable = enable ? CBool.TRUE : CBool.FALSE; } + /** Regular expression matching tokens to suppress. */ + public String suppress_regex; + /** Tokens to provide to the whisper decoder as an initial prompt. * These are prepended to any existing text context from a previous call. */ public String initial_prompt; @@ -319,7 +322,7 @@ protected List getFieldOrder() { "no_context", "single_segment", "no_timestamps", "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps", "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx", - "tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", + "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty", "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search", "new_segment_callback", "new_segment_callback_user_data", diff --git a/examples/command/command.cpp b/examples/command/command.cpp index f86a3449db7..ec749d60247 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -52,6 +52,9 @@ struct whisper_params { std::string prompt; std::string context; std::string grammar; + + // A regular expression that matches tokens to suppress + std::string suppress_regex; }; void whisper_print_usage(int argc, char ** argv, const whisper_params & params); @@ -85,6 +88,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; } else if ( arg == "--grammar") { params.grammar = argv[++i]; } else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); } + else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -122,6 +126,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str()); fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); + fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str()); fprintf(stderr, "\n"); } @@ -167,6 +172,8 @@ std::string transcribe( wparams.initial_prompt = params.context.data(); + wparams.suppress_regex = params.suppress_regex.c_str(); + const auto & grammar_parsed = params.grammar_parsed; auto grammar_rules = grammar_parsed.c_rules(); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 42b067e718d..af8b5ca4e01 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -78,6 +79,9 @@ struct whisper_params { // [TDRZ] speaker turn string std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line + // A regular expression that matches tokens to suppress + std::string suppress_regex; + std::string openvino_encode_device = "CPU"; std::string dtw = ""; @@ -160,6 +164,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; } else if ( arg == "--grammar") { params.grammar = argv[++i]; } else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; } else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); } @@ -223,6 +228,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str()); fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str()); fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); @@ -1033,6 +1039,8 @@ int main(int argc, char ** argv) { wparams.tdrz_enable = params.tinydiarize; // [TDRZ] + wparams.suppress_regex = params.suppress_regex.c_str(); + wparams.initial_prompt = params.prompt.c_str(); wparams.greedy.best_of = params.best_of; diff --git a/whisper.cpp b/whisper.cpp index d50c788b3c6..fd9737379db 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -4553,6 +4553,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.tdrz_enable =*/ false, + /* suppress_regex =*/ nullptr, + /*.initial_prompt =*/ nullptr, /*.prompt_tokens =*/ nullptr, /*.prompt_n_tokens =*/ 0, @@ -4796,6 +4798,17 @@ static void whisper_process_logits( params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); } + // suppress any tokens matching a regular expression + // ref: https://github.com/openai/whisper/discussions/1041 + if (params.suppress_regex != nullptr) { + std::regex re(params.suppress_regex); + for (std::pair token_id : vocab.token_to_id) { + if (std::regex_match(token_id.first, re)) { + logits[token_id.second] = -INFINITY; + } + } + } + // suppress non-speech tokens // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 if (params.suppress_non_speech_tokens) { diff --git a/whisper.h b/whisper.h index bd8d8df828a..6a875d3bbb9 100644 --- a/whisper.h +++ b/whisper.h @@ -505,6 +505,9 @@ extern "C" { // [EXPERIMENTAL] [TDRZ] tinydiarize bool tdrz_enable; // enable tinydiarize speaker turn detection + // A regular expression that matches tokens to suppress + const char * suppress_regex; + // tokens to provide to the whisper decoder as initial prompt // these are prepended to any existing text context from a previous call // use whisper_tokenize() to convert text to tokens