diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 415c3b33de7..42b067e718d 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -1,6 +1,7 @@ #include "common.h" #include "whisper.h" +#include "grammar-parser.h" #include #include @@ -38,9 +39,10 @@ struct whisper_params { int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size; int32_t audio_ctx = 0; - float word_thold = 0.01f; - float entropy_thold = 2.40f; - float logprob_thold = -1.00f; + float word_thold = 0.01f; + float entropy_thold = 2.40f; + float logprob_thold = -1.00f; + float grammar_penalty = 100.0f; bool speed_up = false; bool debug_mode = false; @@ -70,6 +72,8 @@ struct whisper_params { std::string prompt; std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; std::string model = "models/ggml-base.en.bin"; + std::string grammar; + std::string grammar_rule; // [TDRZ] speaker turn string std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line @@ -80,6 +84,8 @@ struct whisper_params { std::vector fname_inp = {}; std::vector fname_out = {}; + + grammar_parser::parse_state grammar_parsed; }; void whisper_print_usage(int argc, char ** argv, const whisper_params & params); @@ -154,6 +160,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if ( arg == "--grammar") { params.grammar = argv[++i]; } + else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; } + else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -214,6 +223,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); + fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str()); + fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); fprintf(stderr, "\n"); } @@ -926,6 +938,29 @@ int main(int argc, char ** argv) { // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); + if (!params.grammar.empty()) { + auto & grammar = params.grammar_parsed; + if (is_file_exist(params.grammar.c_str())) { + // read grammar from file + std::ifstream ifs(params.grammar.c_str()); + const std::string txt = std::string((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); + grammar = grammar_parser::parse(txt.c_str()); + } else { + // read grammar from string + grammar = grammar_parser::parse(params.grammar.c_str()); + } + + // will be empty (default) if there are parse errors + if (grammar.rules.empty()) { + fprintf(stderr, "error: failed to parse grammar \"%s\"\n", params.grammar.c_str()); + return 4; + } else { + fprintf(stderr, "%s: grammar:\n", __func__); + grammar_parser::print_grammar(stderr, grammar); + fprintf(stderr, "\n"); + } + } + for (int f = 0; f < (int) params.fname_inp.size(); ++f) { const auto fname_inp = params.fname_inp[f]; const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f]; @@ -972,7 +1007,8 @@ int main(int argc, char ** argv) { { whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; + const bool use_grammar = (!params.grammar_parsed.rules.empty() && !params.grammar_rule.empty()); + wparams.strategy = (params.beam_size > 1 || use_grammar) ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; wparams.print_realtime = false; wparams.print_progress = params.print_progress; @@ -1010,6 +1046,20 @@ int main(int argc, char ** argv) { whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 }; + const auto & grammar_parsed = params.grammar_parsed; + auto grammar_rules = grammar_parsed.c_rules(); + + if (use_grammar) { + if (grammar_parsed.symbol_ids.find(params.grammar_rule) == grammar_parsed.symbol_ids.end()) { + fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, params.grammar_rule.c_str()); + } else { + wparams.grammar_rules = grammar_rules.data(); + wparams.n_grammar_rules = grammar_rules.size(); + wparams.i_start_rule = grammar_parsed.symbol_ids.at(params.grammar_rule); + wparams.grammar_penalty = params.grammar_penalty; + } + } + // this callback is called on each new segment if (!wparams.print_realtime) { wparams.new_segment_callback = whisper_print_segment_callback;