Skip to content

Commit

Permalink
whisper : add initial_prompt param (ggerganov#645)
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Mar 29, 2023
1 parent c2b23b3 commit 3a6c2a9
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 36 deletions.
19 changes: 1 addition & 18 deletions examples/addon.node/addon.cpp
Expand Up @@ -160,22 +160,6 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
return 3;
}

// initial prompt
std::vector<whisper_token> prompt_tokens;

if (!params.prompt.empty()) {
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));

fprintf(stderr, "\n");
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
fprintf(stderr, "initial tokens: [ ");
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
fprintf(stderr, "%d ", prompt_tokens[i]);
}
fprintf(stderr, "]\n");
}

for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
Expand Down Expand Up @@ -243,8 +227,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;

wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.initial_prompt = params.prompt.c_str();

whisper_print_user_data user_data = { &params, &pcmf32s };

Expand Down
19 changes: 1 addition & 18 deletions examples/main/main.cpp
Expand Up @@ -639,22 +639,6 @@ int main(int argc, char ** argv) {
return 3;
}

// initial prompt
std::vector<whisper_token> prompt_tokens;

if (!params.prompt.empty()) {
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));

fprintf(stderr, "\n");
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
fprintf(stderr, "initial tokens: [ ");
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
fprintf(stderr, "%d ", prompt_tokens[i]);
}
fprintf(stderr, "]\n");
}

for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
Expand Down Expand Up @@ -718,8 +702,7 @@ int main(int argc, char ** argv) {

wparams.speed_up = params.speed_up;

wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.initial_prompt = params.prompt.c_str();

wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
Expand Down
10 changes: 10 additions & 0 deletions whisper.cpp
Expand Up @@ -3121,6 +3121,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.speed_up =*/ false,
/*.audio_ctx =*/ 0,

/*.initial_prompt =*/ nullptr,
/*.prompt_tokens =*/ nullptr,
/*.prompt_n_tokens =*/ 0,

Expand Down Expand Up @@ -3793,6 +3794,15 @@ int whisper_full_with_state(
prompt_past.clear();
}

// initial prompt
if (!params.prompt_tokens && params.initial_prompt) {
std::vector<whisper_token> prompt_tokens;
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
params.prompt_tokens = prompt_tokens.data();
params.prompt_n_tokens = prompt_tokens.size();
}

// prepend the prompt tokens to the prompt_past
if (params.prompt_tokens && params.prompt_n_tokens > 0) {
// parse tokens from the pointer
Expand Down
1 change: 1 addition & 0 deletions whisper.h
Expand Up @@ -356,6 +356,7 @@ extern "C" {

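// tokens to provide to the whisper decoder as initial prompt
// these are prepended to any existing text context from a previous call
// initial_prompt is the plain-text alternative: it is tokenized internally, and only when prompt_tokens is not set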
const char * initial_prompt;
const whisper_token * prompt_tokens;
int prompt_n_tokens;

Expand Down

0 comments on commit 3a6c2a9

Please sign in to comment.