Skip to content

Commit dfb441b

Browse files
committed
added command line arugment for p-rng seed
1 parent 501f0da commit dfb441b

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

main.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <string>
2424
#include <vector>
2525
#include <random>
26+
#include <chrono>
2627
#include <functional>
2728

2829
#if defined(_MSC_VER)
@@ -34,7 +35,16 @@
3435

3536
int32_t NUM_RETURN_SEQUENCES = 4; //hardcoding this for now, analagous to "num_return_sequences arugment to inference_speech"
3637

37-
std::mt19937 generator(245645656);
38+
39+
auto now = std::chrono::system_clock::now();
40+
auto duration = now.time_since_epoch();
41+
auto milliseconds = std::chrono::duration_cast<std::chrono::milliseconds>(duration);
42+
43+
// Use the milliseconds count as a seed for the random number generator
44+
unsigned time_seed = milliseconds.count();
45+
46+
std::mt19937 generator(time_seed);
47+
3848
std::uniform_real_distribution<float> distribution(0.0, 1.0);
3949
std::normal_distribution<double> normal_distribution(0.0,1.0);
4050

@@ -5886,6 +5896,7 @@ bool mel_code_vectors_match(const std::vector<std::vector<int>>& vec1, const std
58865896

58875897
void test_autoregressive(){
58885898

5899+
generator.seed(245645656);
58895900

58905901
std::vector<gpt_vocab::id> tokens = ::parse_tokens_from_string("255,15,55,49,9,9,9,2,134,16,51,31,2,19,46,18,176,13,0,0", ','); //"Based... Dr. Freeman?"
58915902

@@ -5975,6 +5986,8 @@ int main(int argc, char ** argv) {
59755986
message = argv[i + 1];
59765987
} else if (std::string(argv[i]) == "--output") {
59775988
outputPath = argv[i + 1];
5989+
}else if (std::string(argv[i]) == "--seed") {
5990+
generator.seed(std::stoi(argv[i + 1]));
59785991
}
59795992
}
59805993

0 commit comments

Comments
 (0)