minor tweaks to the transformer example
arrufat committed Jan 22, 2025
1 parent 8fdd2a6 commit 4d116ab
Showing 3 changed files with 28 additions and 25 deletions.
39 changes: 21 additions & 18 deletions examples/slm_basic_train_ex.cpp
@@ -49,6 +49,9 @@

 // ----------------------------------------------------------------------------------------
 
+using namespace std;
+using namespace dlib;
+
 // We treat each character as a token ID in [0..255].
 const int MAX_TOKEN_ID = 255;
 const int PAD_TOKEN = 256; // an extra "pad" token if needed
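
Note that the two using-directives added above are what let the rest of this commit drop the explicit dlib:: qualifiers from the example code.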
@@ -66,13 +69,13 @@ std::vector<int> char_based_tokenize(const std::string& text)
 }
 
 // Function to shuffle samples and labels in sync
-void shuffle_samples_and_labels(std::vector<dlib::matrix<int, 0, 1>>& samples, std::vector<unsigned long>& labels) {
+void shuffle_samples_and_labels(std::vector<matrix<int, 0, 1>>& samples, std::vector<unsigned long>& labels) {
     std::vector<size_t> indices(samples.size());
     std::iota(indices.begin(), indices.end(), 0); // Fill with 0, 1, 2, ..., N-1
     std::shuffle(indices.begin(), indices.end(), std::default_random_engine{});
 
     // Create temporary vectors to hold shuffled data
-    std::vector<dlib::matrix<int, 0, 1>> shuffled_samples(samples.size());
+    std::vector<matrix<int, 0, 1>> shuffled_samples(samples.size());
     std::vector<unsigned long> shuffled_labels(labels.size());
 
     // Apply the shuffle
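
The shuffle works by permuting a vector of indices and applying that one permutation to both containers, which keeps samples[i] paired with labels[i]. A minimal self-contained sketch of the same idiom, using toy data rather than the example's types:

    #include <algorithm>
    #include <numeric>
    #include <random>
    #include <vector>

    int main()
    {
        std::vector<int> samples{10, 20, 30, 40};
        std::vector<unsigned long> labels{0, 1, 2, 3};

        // Build the identity permutation 0, 1, ..., N-1, then shuffle it.
        std::vector<size_t> idx(samples.size());
        std::iota(idx.begin(), idx.end(), 0);
        // A default-constructed engine is seeded identically on every run,
        // so the permutation is reproducible (as in the example above).
        std::shuffle(idx.begin(), idx.end(), std::default_random_engine{});

        // Apply the same permutation to both containers in lockstep.
        std::vector<int> shuffled_samples(samples.size());
        std::vector<unsigned long> shuffled_labels(labels.size());
        for (size_t i = 0; i < idx.size(); ++i)
        {
            shuffled_samples[i] = samples[idx[i]];
            shuffled_labels[i] = labels[idx[i]];
        }
        samples.swap(shuffled_samples);
        labels.swap(shuffled_labels);
    }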
@@ -93,15 +96,15 @@ int main(int argc, char** argv)
 {
     try
     {
-        dlib::command_line_parser parser;
+        command_line_parser parser;
         parser.add_option("train", "Train a small transformer on the built-in Shakespeare text");
         parser.add_option("generate", "Generate text from a previously trained model (needs shakespeare_prompt)");
         parser.add_option("learning-rate", "Set the learning rate for training (default: 1e-4)", 1);
         parser.add_option("batch-size", "Set the mini-batch size for training (default: 64)", 1);
         parser.add_option("generation-length", "Set the length of generated text (default: 400)", 1);
-        parser.add_option("alpha", "Set the initial learning rate for Adam optimizer (default: 0.004)", 1);
-        parser.add_option("beta1", "Set the decay rate for the first moment estimate (default: 0.9)", 1);
-        parser.add_option("beta2", "Set the decay rate for the second moment estimate (default: 0.999)", 1);
+        parser.add_option("alpha", "Set the weight decay for Adam optimizer (default: 0.004)", 1);
+        parser.add_option("beta1", "Set the first moment coefficient (default: 0.9)", 1);
+        parser.add_option("beta2", "Set the second moment coefficient (default: 0.999)", 1);
         parser.add_option("max-samples", "Set the maximum number of training samples (default: 50000)", 1);
         parser.add_option("shuffle", "Shuffle training sequences and labels before training (default: false)");
         parser.parse(argc, argv);
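
The reworded descriptions for alpha, beta1, and beta2 match dlib's adam solver, whose constructor takes the weight decay first and the two moment coefficients after it; the learning rate itself is managed by dnn_trainer, not the solver. A minimal sketch:

    #include <dlib/dnn.h>

    // dlib::adam is constructed as adam(weight_decay, momentum1, momentum2),
    // mirroring the example's --alpha, --beta1, and --beta2 options.
    dlib::adam solver(0.004f, 0.9f, 0.999f);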
@@ -122,7 +125,7 @@ int main(int argc, char** argv)
         const size_t max_samples = get_option(parser, "max-samples",50000); // Default maximum number of training samples
 
         // We define a minimal config for demonstration
-        const long vocab_size = 257; // 0..255 for chars + 1 pad token
+        const long vocab_size = MAX_TOKEN_ID + 1 + 1; // 256 for chars + 1 pad token
         const long num_layers = 3;
         const long num_heads = 4;
         const long embedding_dim = 64;
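
With MAX_TOKEN_ID = 255, the new expression still evaluates to 257 (255 + 1 + 1): 256 byte-valued character tokens with IDs 0..255, plus the pad token with ID 256. The change simply derives the magic number from the constants it came from.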
@@ -136,8 +139,8 @@ int main(int argc, char** argv)
             embedding_dim,
             max_seq_len,
             use_squeezing,
-            dlib::gelu,
-            dlib::dropout_10
+            gelu,
+            dropout_10
         >;
 
         // For GPU usage (if any), set gpus = {0} for a single GPU, etc.
@@ -181,7 +184,7 @@ int main(int argc, char** argv)
             return 0;
         }
 
-        std::vector<dlib::matrix<int, 0, 1>> samples;
+        std::vector<matrix<int, 0, 1>> samples;
         std::vector<unsigned long> labels;
 
         // Let's create a training set of about (N) samples from the text
@@ -190,7 +193,7 @@ int main(int argc, char** argv)
         const size_t N = (max_sequences < max_samples) ? max_sequences : max_samples;
         for (size_t start = 0; start < N; ++start)
         {
-            dlib::matrix<int, 0, 1> seq(max_seq_len, 1);
+            matrix<int, 0, 1> seq(max_seq_len, 1);
             for (long t = 0; t < max_seq_len; ++t)
                 seq(t, 0) = full_tokens[start + t];
             samples.push_back(seq);
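
Each sample is a sliding window of max_seq_len consecutive tokens starting at offset start; the matching label, pushed in a line elided from this hunk, is presumably the token that follows the window, full_tokens[start + max_seq_len], i.e. the standard next-character-prediction setup.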
@@ -207,11 +210,11 @@ int main(int argc, char** argv)
         // 3) Construct the network in training mode
         using net_type = my_transformer_cfg::network_type<true>;
         net_type net;
-        if (dlib::file_exists(model_file))
-            dlib::deserialize(model_file) >> net;
+        if (file_exists(model_file))
+            deserialize(model_file) >> net;
 
         // 4) Create dnn_trainer
-        dlib::dnn_trainer<net_type, dlib::adam> trainer(net, dlib::adam(alpha, beta1, beta2), gpus);
+        dnn_trainer<net_type, adam> trainer(net, adam(alpha, beta1, beta2), gpus);
         trainer.set_learning_rate(learning_rate);
         trainer.set_min_learning_rate(1e-6);
         trainer.set_mini_batch_size(batch_size);
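
The elided steps 5 and 6 between this setup and "7) Save the model" run the optimization itself; with dlib's dnn_trainer the canonical calls are along these lines (a sketch, not the example's exact code):

    trainer.be_verbose();            // print progress to stdout
    trainer.train(samples, labels);  // blocks until the learning rate decays
                                     // below set_min_learning_rate()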
@@ -233,7 +236,7 @@ int main(int argc, char** argv)

         // 7) Save the model
         net.clean();
-        dlib::serialize(model_file) << net;
+        serialize(model_file) << net;
         std::cout << "Model saved to " << model_file << "\n";
     }

@@ -246,9 +249,9 @@ int main(int argc, char** argv)
         // 1) Load the trained model
         using net_infer = my_transformer_cfg::network_type<false>;
         net_infer net;
-        if (dlib::file_exists(model_file))
+        if (file_exists(model_file))
         {
-            dlib::deserialize(model_file) >> net;
+            deserialize(model_file) >> net;
             std::cout << "Loaded model from " << model_file << "\n";
         }
         else
@@ -274,7 +277,7 @@ int main(int argc, char** argv)
         const auto prompt_tokens = char_based_tokenize(prompt_text);
 
         // Put into a dlib matrix
-        dlib::matrix<int, 0, 1> input_seq(max_seq_len, 1);
+        matrix<int, 0, 1> input_seq(max_seq_len, 1);
         // Fill with pad if prompt is shorter than max_seq_len
         for (long i = 0; i < max_seq_len; ++i)
         {
4 changes: 2 additions & 2 deletions examples/slm_data.h
@@ -6,7 +6,7 @@
 #include <algorithm>
 
 // Utility function to concatenate text parts
-std::string concatenateTexts(const std::vector<std::string>& texts) {
+inline std::string concatenateTexts(const std::vector<std::string>& texts) {
     std::string result;
     for (const auto& text : texts) {
         result += text;
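
Marking concatenateTexts as inline is what makes it safe to define in a header: without it, including slm_data.h from more than one translation unit would define the function twice and fail at link time with a duplicate-symbol (ODR) error.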
@@ -590,4 +590,4 @@ And you shall understand from me her mind.
)";

#endif // SlmData_H
#endif // SlmData_H
10 changes: 5 additions & 5 deletions examples/slm_defs.h
@@ -214,11 +214,11 @@ namespace transformer
     template<bool is_training>
     using network_type = std::conditional_t<is_training,
         classification_head<USE_SQUEEZING, activation_func, VOCAB_SIZE, EMBEDDING_DIM,
-        repeat<NUM_LAYERS, t_transformer_block,
-        positional_embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>,
+            repeat<NUM_LAYERS, t_transformer_block,
+            positional_embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>,
         classification_head<USE_SQUEEZING, activation_func, VOCAB_SIZE, EMBEDDING_DIM,
-        repeat<NUM_LAYERS, i_transformer_block,
-        positional_embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>
+            repeat<NUM_LAYERS, i_transformer_block,
+            positional_embeddings<VOCAB_SIZE, EMBEDDING_DIM, input<matrix<int, 0, 1>>>>>
     >;
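
The std::conditional_t selects between two complete network types at compile time: a training network built from t_transformer_block and an inference network built from i_transformer_block (this hunk appears to change only the indentation of the repeated blocks). A minimal sketch of the same idiom, with hypothetical stand-in types:

    #include <type_traits>

    struct train_net { /* e.g. keeps dropout layers active */ };
    struct infer_net { /* e.g. dropout replaced by a no-op */ };

    template <bool is_training>
    using net_t = std::conditional_t<is_training, train_net, infer_net>;

    static_assert(std::is_same_v<net_t<true>, train_net>, "");
    static_assert(std::is_same_v<net_t<false>, infer_net>, "");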

     /**
@@ -283,4 +283,4 @@ namespace transformer
     */
 }
 
-#endif // SlmNet_H
\ No newline at end of file
+#endif // SlmNet_H
