From 4d116abdad89bb926fadfd4060a870e62875974c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Arrufat?= Date: Wed, 22 Jan 2025 20:41:45 +0900 Subject: [PATCH] minor tweaks to the transformer example --- examples/slm_basic_train_ex.cpp | 39 ++++++++++++++++++--------------- examples/slm_data.h | 4 ++-- examples/slm_defs.h | 10 ++++----- 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/examples/slm_basic_train_ex.cpp b/examples/slm_basic_train_ex.cpp index e4322fc774..010fbf49db 100644 --- a/examples/slm_basic_train_ex.cpp +++ b/examples/slm_basic_train_ex.cpp @@ -49,6 +49,9 @@ // ---------------------------------------------------------------------------------------- +using namespace std; +using namespace dlib; + // We treat each character as a token ID in [0..255]. const int MAX_TOKEN_ID = 255; const int PAD_TOKEN = 256; // an extra "pad" token if needed @@ -66,13 +69,13 @@ std::vector char_based_tokenize(const std::string& text) } // Function to shuffle samples and labels in sync -void shuffle_samples_and_labels(std::vector>& samples, std::vector& labels) { +void shuffle_samples_and_labels(std::vector>& samples, std::vector& labels) { std::vector indices(samples.size()); std::iota(indices.begin(), indices.end(), 0); // Fill with 0, 1, 2, ..., N-1 std::shuffle(indices.begin(), indices.end(), std::default_random_engine{}); // Create temporary vectors to hold shuffled data - std::vector> shuffled_samples(samples.size()); + std::vector> shuffled_samples(samples.size()); std::vector shuffled_labels(labels.size()); // Apply the shuffle @@ -93,15 +96,15 @@ int main(int argc, char** argv) { try { - dlib::command_line_parser parser; + command_line_parser parser; parser.add_option("train", "Train a small transformer on the built-in Shakespeare text"); parser.add_option("generate", "Generate text from a previously trained model (needs shakespeare_prompt)"); parser.add_option("learning-rate", "Set the learning rate for training (default: 1e-4)", 1); parser.add_option("batch-size", "Set the mini-batch size for training (default: 64)", 1); parser.add_option("generation-length", "Set the length of generated text (default: 400)", 1); - parser.add_option("alpha", "Set the initial learning rate for Adam optimizer (default: 0.004)", 1); - parser.add_option("beta1", "Set the decay rate for the first moment estimate (default: 0.9)", 1); - parser.add_option("beta2", "Set the decay rate for the second moment estimate (default: 0.999)", 1); + parser.add_option("alpha", "Set the weight decay for Adam optimizer (default: 0.004)", 1); + parser.add_option("beta1", "Set the first moment coefficient (default: 0.9)", 1); + parser.add_option("beta2", "Set the second moment coefficient (default: 0.999)", 1); parser.add_option("max-samples", "Set the maximum number of training samples (default: 50000)", 1); parser.add_option("shuffle", "Shuffle training sequences and labels before training (default: false)"); parser.parse(argc, argv); @@ -122,7 +125,7 @@ int main(int argc, char** argv) const size_t max_samples = get_option(parser, "max-samples",50000); // Default maximum number of training samples // We define a minimal config for demonstration - const long vocab_size = 257; // 0..255 for chars + 1 pad token + const long vocab_size = MAX_TOKEN_ID + 1 + 1; // 256 for chars + 1 pad token const long num_layers = 3; const long num_heads = 4; const long embedding_dim = 64; @@ -136,8 +139,8 @@ int main(int argc, char** argv) embedding_dim, max_seq_len, use_squeezing, - dlib::gelu, - dlib::dropout_10 + gelu, + dropout_10 >; // For GPU usage (if any), set gpus = {0} for a single GPU, etc. @@ -181,7 +184,7 @@ int main(int argc, char** argv) return 0; } - std::vector> samples; + std::vector> samples; std::vector labels; // Let's create a training set of about (N) samples from the text @@ -190,7 +193,7 @@ int main(int argc, char** argv) const size_t N = (max_sequences < max_samples) ? max_sequences : max_samples; for (size_t start = 0; start < N; ++start) { - dlib::matrix seq(max_seq_len, 1); + matrix seq(max_seq_len, 1); for (long t = 0; t < max_seq_len; ++t) seq(t, 0) = full_tokens[start + t]; samples.push_back(seq); @@ -207,11 +210,11 @@ int main(int argc, char** argv) // 3) Construct the network in training mode using net_type = my_transformer_cfg::network_type; net_type net; - if (dlib::file_exists(model_file)) - dlib::deserialize(model_file) >> net; + if (file_exists(model_file)) + deserialize(model_file) >> net; // 4) Create dnn_trainer - dlib::dnn_trainer trainer(net, dlib::adam(alpha, beta1, beta2), gpus); + dnn_trainer trainer(net, adam(alpha, beta1, beta2), gpus); trainer.set_learning_rate(learning_rate); trainer.set_min_learning_rate(1e-6); trainer.set_mini_batch_size(batch_size); @@ -233,7 +236,7 @@ int main(int argc, char** argv) // 7) Save the model net.clean(); - dlib::serialize(model_file) << net; + serialize(model_file) << net; std::cout << "Model saved to " << model_file << "\n"; } @@ -246,9 +249,9 @@ int main(int argc, char** argv) // 1) Load the trained model using net_infer = my_transformer_cfg::network_type; net_infer net; - if (dlib::file_exists(model_file)) + if (file_exists(model_file)) { - dlib::deserialize(model_file) >> net; + deserialize(model_file) >> net; std::cout << "Loaded model from " << model_file << "\n"; } else @@ -274,7 +277,7 @@ int main(int argc, char** argv) const auto prompt_tokens = char_based_tokenize(prompt_text); // Put into a dlib matrix - dlib::matrix input_seq(max_seq_len, 1); + matrix input_seq(max_seq_len, 1); // Fill with pad if prompt is shorter than max_seq_len for (long i = 0; i < max_seq_len; ++i) { diff --git a/examples/slm_data.h b/examples/slm_data.h index 86f3bdcc10..37c08f29b8 100644 --- a/examples/slm_data.h +++ b/examples/slm_data.h @@ -6,7 +6,7 @@ #include // Utility function to concatenate text parts -std::string concatenateTexts(const std::vector& texts) { +inline std::string concatenateTexts(const std::vector& texts) { std::string result; for (const auto& text : texts) { result += text; @@ -590,4 +590,4 @@ And you shall understand from me her mind. )"; -#endif // SlmData_H \ No newline at end of file +#endif // SlmData_H diff --git a/examples/slm_defs.h b/examples/slm_defs.h index 786b1ffdd7..d556fc0dab 100644 --- a/examples/slm_defs.h +++ b/examples/slm_defs.h @@ -214,11 +214,11 @@ namespace transformer template using network_type = std::conditional_t>>>>, + repeat>>>>, classification_head>>>> + repeat>>>> >; /** @@ -283,4 +283,4 @@ namespace transformer */ } -#endif // SlmNet_H \ No newline at end of file +#endif // SlmNet_H