Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented command-style grammar in the main example. #1998

Merged
merged 2 commits into from Mar 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
58 changes: 54 additions & 4 deletions examples/main/main.cpp
@@ -1,6 +1,7 @@
#include "common.h"

#include "whisper.h"
#include "grammar-parser.h"

#include <cmath>
#include <fstream>
Expand Down Expand Up @@ -38,9 +39,10 @@ struct whisper_params {
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
int32_t audio_ctx = 0;

float word_thold = 0.01f;
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;
float word_thold = 0.01f;
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;
float grammar_penalty = 100.0f;

bool speed_up = false;
bool debug_mode = false;
Expand Down Expand Up @@ -70,6 +72,8 @@ struct whisper_params {
std::string prompt;
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
std::string model = "models/ggml-base.en.bin";
std::string grammar;
std::string grammar_rule;

// [TDRZ] speaker turn string
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
Expand All @@ -80,6 +84,8 @@ struct whisper_params {

std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_out = {};

grammar_parser::parse_state grammar_parsed;
};

void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
Expand Down Expand Up @@ -154,6 +160,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
Expand Down Expand Up @@ -214,6 +223,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
fprintf(stderr, "\n");
}

Expand Down Expand Up @@ -926,6 +938,29 @@ int main(int argc, char ** argv) {
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);

if (!params.grammar.empty()) {
auto & grammar = params.grammar_parsed;
if (is_file_exist(params.grammar.c_str())) {
// read grammar from file
std::ifstream ifs(params.grammar.c_str());
const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
grammar = grammar_parser::parse(txt.c_str());
} else {
// read grammar from string
grammar = grammar_parser::parse(params.grammar.c_str());
}

// will be empty (default) if there are parse errors
if (grammar.rules.empty()) {
fprintf(stderr, "error: failed to parse grammar \"%s\"\n", params.grammar.c_str());
return 4;
} else {
fprintf(stderr, "%s: grammar:\n", __func__);
grammar_parser::print_grammar(stderr, grammar);
fprintf(stderr, "\n");
}
}

for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
Expand Down Expand Up @@ -972,7 +1007,8 @@ int main(int argc, char ** argv) {
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
const bool use_grammar = (!params.grammar_parsed.rules.empty() && !params.grammar_rule.empty());
wparams.strategy = (params.beam_size > 1 || use_grammar) ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;

wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
Expand Down Expand Up @@ -1010,6 +1046,20 @@ int main(int argc, char ** argv) {

whisper_print_user_data user_data = { &params, &pcmf32s, 0 };

const auto & grammar_parsed = params.grammar_parsed;
auto grammar_rules = grammar_parsed.c_rules();

if (use_grammar) {
if (grammar_parsed.symbol_ids.find(params.grammar_rule) == grammar_parsed.symbol_ids.end()) {
fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, params.grammar_rule.c_str());
} else {
wparams.grammar_rules = grammar_rules.data();
wparams.n_grammar_rules = grammar_rules.size();
wparams.i_start_rule = grammar_parsed.symbol_ids.at(params.grammar_rule);
wparams.grammar_penalty = params.grammar_penalty;
}
}

// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
Expand Down