Skip to content

Commit

Permalink
cmdline option for custom amount of model parts (--n_parts N)
Browse files Browse the repository at this point in the history
  • Loading branch information
anzz1 committed Mar 21, 2023
1 parent 8cf9f34 commit b839231
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
9 changes: 5 additions & 4 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ struct llama_model {
};

// load the model's weights from a file
bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, int n_parts, ggml_type memory_type = GGML_TYPE_F32) {
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());

std::vector<char> f_buf(1024*1024);
Expand Down Expand Up @@ -127,7 +127,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
}

int n_ff = 0;
int n_parts = 0;

// load hparams
{
Expand All @@ -145,7 +144,9 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
hparams.n_ctx = n_ctx;

n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
n_parts = LLAMA_N_PARTS.at(hparams.n_embd);

if (n_parts < 1)
n_parts = LLAMA_N_PARTS.at(hparams.n_embd);

fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_ctx = %d\n", __func__, hparams.n_ctx);
Expand Down Expand Up @@ -839,7 +840,7 @@ int main(int argc, char ** argv) {
{
const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
const int64_t t_start_us = ggml_time_us();
if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) {
if (!llama_model_load(params.model, model, vocab, params.n_ctx, params.n_parts, memory_type)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return 1;
}
Expand Down
3 changes: 3 additions & 0 deletions utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.antiprompt.push_back(argv[++i]);
} else if (arg == "--ignore-eos") {
params.ignore_eos = true;
} else if (arg == "--n_parts") {
params.n_parts = std::stoi(argv[++i]);
} else if (arg == "-h" || arg == "--help") {
gpt_print_usage(argc, argv, params);
exit(0);
Expand Down Expand Up @@ -116,6 +118,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n");
fprintf(stderr, " --memory_f16 use f16 instead of f32 for memory key+value\n");
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
Expand Down
1 change: 1 addition & 0 deletions utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct gpt_params {
int32_t repeat_last_n = 64; // last n tokens to penalize
int32_t n_ctx = 512; //context size
bool memory_f16 = false; // use f16 instead of f32 for memory kv
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)

// sampling parameters
int32_t top_k = 40;
Expand Down

0 comments on commit b839231

Please sign in to comment.