@@ -4710,7 +4710,7 @@ std::vector<float> load_f32_vector(const std::string& filename, size_t nBytes) {
4710
4710
}
4711
4711
4712
4712
4713
- std::pair<std::vector<std::vector<float >>, std::vector<std::vector<int >>> autoregressive (std::vector<gpt_vocab::id> tokens)
4713
+ std::pair<std::vector<std::vector<float >>, std::vector<std::vector<int >>> autoregressive (std::vector<gpt_vocab::id> tokens, std::string voice_path )
4714
4714
{
4715
4715
4716
4716
@@ -4870,7 +4870,7 @@ std::pair<std::vector<std::vector<float>>, std::vector<std::vector<int>>> autore
4870
4870
struct ggml_tensor *auto_conditioning_tensor = ggml_graph_get_tensor (gf, " auto_conditioning" );
4871
4871
4872
4872
4873
- std::vector<float > auto_conditioning_vector = load_f32_vector (" ../models/mol.bin " , 4096 );
4873
+ std::vector<float > auto_conditioning_vector = load_f32_vector (voice_path , 4096 );
4874
4874
4875
4875
ggml_backend_tensor_set (auto_conditioning_tensor, auto_conditioning_vector.data (), 0 , 1024 *ggml_element_size (auto_conditioning_tensor));
4876
4876
@@ -5889,7 +5889,7 @@ void test_autoregressive(){
5889
5889
5890
5890
std::vector<gpt_vocab::id> tokens = ::parse_tokens_from_string (" 255,15,55,49,9,9,9,2,134,16,51,31,2,19,46,18,176,13,0,0" , ' ,' ); // "Based... Dr. Freeman?"
5891
5891
5892
- std::pair<std::vector<std::vector<float >>, std::vector<std::vector<int >>> autoregressive_result = autoregressive (tokens);
5892
+ std::pair<std::vector<std::vector<float >>, std::vector<std::vector<int >>> autoregressive_result = autoregressive (tokens, " ../models/mol.bin " );
5893
5893
5894
5894
std::vector<std::vector<float >> trimmed_latents = autoregressive_result.first ;
5895
5895
std::vector<std::vector<int >> sequences = autoregressive_result.second ;
@@ -5960,6 +5960,29 @@ int main(int argc, char ** argv) {
5960
5960
5961
5961
5962
5962
5963
+ std::string defaultMessage = " this is a test message." ;
5964
+ std::string defaultVoicePath = " ../models/mol.bin" ;
5965
+ std::string defaultOutputPath = " ./output.wav" ;
5966
+ std::string message = defaultMessage;
5967
+ std::string voicePath = defaultVoicePath;
5968
+ std::string outputPath = defaultOutputPath;
5969
+
5970
+ // Parse command line arguments
5971
+ for (int i = 1 ; i < argc - 1 ; ++i) {
5972
+ if (std::string (argv[i]) == " --voice" ) {
5973
+ voicePath = argv[i + 1 ];
5974
+ } else if (std::string (argv[i]) == " --message" ) {
5975
+ message = argv[i + 1 ];
5976
+ } else if (std::string (argv[i]) == " --output" ) {
5977
+ outputPath = argv[i + 1 ];
5978
+ }
5979
+ }
5980
+
5981
+ // Use the updated values
5982
+ std::cout << " Message: " << message << std::endl;
5983
+ std::cout << " Voice Path: " << voicePath << std::endl;
5984
+
5985
+
5963
5986
5964
5987
gpt_vocab vocab;
5965
5988
gpt_vocab_init (" ../models/tokenizer.json" , vocab);
@@ -5972,7 +5995,7 @@ int main(int argc, char ** argv) {
5972
5995
5973
5996
// std::string message = "this[SPACE]is[SPACE]a[SPACE]test[SPACE]message";
5974
5997
// std::string message = "tortoise, full process complete.";
5975
- std::string message = " minimum viable product incoming." ;
5998
+ // std::string message = "minimum viable product incoming.";
5976
5999
5977
6000
5978
6001
replaceAll (message, " " ," [SPACE]" );
@@ -5996,7 +6019,7 @@ int main(int argc, char ** argv) {
5996
6019
// std::vector<gpt_vocab::id> tokens = ::parse_tokens_from_string("255,15,55,49,9,9,9,2,134,16,51,31,2,19,46,18,176,13,0,0", ','); //"Based... Dr. Freeman?"
5997
6020
5998
6021
5999
- std::pair<std::vector<std::vector<float >>, std::vector<std::vector<int >>> autoregressive_result = autoregressive (tokens);
6022
+ std::pair<std::vector<std::vector<float >>, std::vector<std::vector<int >>> autoregressive_result = autoregressive (tokens,voicePath );
6000
6023
6001
6024
std::vector<std::vector<float >> trimmed_latents = autoregressive_result.first ;
6002
6025
std::vector<std::vector<int >> sequences = autoregressive_result.second ;
@@ -6100,7 +6123,7 @@ int main(int argc, char ** argv) {
6100
6123
extract_tensor_to_vector ( vocoder_gf->nodes [vocoder_gf->n_nodes -1 ] , audio);
6101
6124
6102
6125
6103
- writeWav (" based?.wav " , audio , 24000 );
6126
+ writeWav (outputPath. c_str () , audio , 24000 );
6104
6127
6105
6128
6106
6129
0 commit comments