common : add HF arg helpers #6234

Merged (2 commits, Mar 22, 2024) · Changes from 1 commit
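
For context, this PR adds command-line helpers for pulling a model directly from Hugging Face. A minimal invocation sketch (the binary name and local path are illustrative; the repo/file pair is the example used in the code comments below):

    ./main -hfr ggml-org/models -hff tinyllama-1.1b/ggml-model-f16.gguf -m models/ggml-model-f16.gguf

When -hfr/-hff are given, the -m path is used as the local target for the downloaded file.
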
common/common.cpp (83 changes: 71 additions & 12 deletions)
@@ -647,6 +647,22 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
params.model = argv[i];
return true;
}
if (arg == "-md" || arg == "--model-draft") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.model_draft = argv[i];
return true;
}
if (arg == "-a" || arg == "--alias") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.model_alias = argv[i];
return true;
}
if (arg == "-mu" || arg == "--model-url") {
if (++i >= argc) {
invalid_param = true;
@@ -655,20 +671,20 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
params.model_url = argv[i];
return true;
}
if (arg == "-md" || arg == "--model-draft") {
if (arg == "-hfr" || arg == "--hf-repo") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.model_draft = argv[i];
params.hf_repo = argv[i];
return true;
}
if (arg == "-a" || arg == "--alias") {
if (arg == "-hff" || arg == "--hf-file") {
if (++i >= argc) {
invalid_param = true;
return true;
}
params.model_alias = argv[i];
params.hf_file = argv[i];
return true;
}
if (arg == "--lora") {
@@ -1403,10 +1419,14 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" layer range to apply the control vector(s) to, start and end inclusive\n");
printf(" -m FNAME, --model FNAME\n");
printf(" model path (default: %s)\n", params.model.c_str());
printf(" -mu MODEL_URL, --model-url MODEL_URL\n");
printf(" model download url (default: %s)\n", params.model_url.c_str());
printf(" -md FNAME, --model-draft FNAME\n");
printf(" draft model for speculative decoding\n");
printf(" -mu MODEL_URL, --model-url MODEL_URL\n");
printf(" model download url (default: %s)\n", params.model_url.c_str());
printf(" -hfr REPO, --hf-repo REPO\n");
printf(" Hugging Face model repository (default: %s)\n", params.hf_repo.c_str());
printf(" -hff FILE, --hf-file FILE\n");
printf(" Hugging Face model file (default: %s)\n", params.hf_file.c_str());

Collaborator review comment (on the -hfr/-hff usage lines above): The default probably does not need to be printed here, since it will always be an empty string, and I do not expect that we will ever change that.

printf(" -ld LOGDIR, --logdir LOGDIR\n");
printf(" path under which to save YAML logs (no logging if unset)\n");
printf(" --override-kv KEY=TYPE:VALUE\n");
@@ -1655,8 +1675,10 @@ void llama_batch_add(

#ifdef LLAMA_USE_CURL

struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
struct llama_model_params params) {
struct llama_model * llama_load_model_from_url(
const char * model_url,
const char * path_model,
const struct llama_model_params & params) {
// Basic validation of the model_url
if (!model_url || strlen(model_url) == 0) {
fprintf(stderr, "%s: invalid model_url\n", __func__);
@@ -1850,25 +1872,62 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
return llama_load_model_from_file(path_model, params);
}

struct llama_model * llama_load_model_from_hf(
const char * repo,
const char * model,
const char * path_model,
const struct llama_model_params & params) {
// construct hugging face model url:
//
// --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
// https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
//
// --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
// https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
//

std::string model_url = "https://huggingface.co/";
model_url += repo;
model_url += "/resolve/main/";
model_url += model;

return llama_load_model_from_url(model_url.c_str(), path_model, params);
}

#else

struct llama_model * llama_load_model_from_url(const char * /*model_url*/, const char * /*path_model*/,
struct llama_model_params /*params*/) {
struct llama_model * llama_load_model_from_url(
const char * /*model_url*/,
const char * /*path_model*/,
const struct llama_model_params & /*params*/) {
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
return nullptr;
}

struct llama_model * llama_load_model_from_hf(
const char * /*repo*/,
const char * /*model*/,
const char * /*path_model*/,
const struct llama_model_params & /*params*/) {
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
return nullptr;
}

#endif // LLAMA_USE_CURL

std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto mparams = llama_model_params_from_gpt_params(params);

llama_model * model = nullptr;
if (!params.model_url.empty()) {

if (!params.hf_repo.empty() && !params.hf_file.empty()) {
model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
} else if (!params.model_url.empty()) {
model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
} else {
model = llama_load_model_from_file(params.model.c_str(), mparams);
}

if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return std::make_tuple(nullptr, nullptr);
@@ -1908,7 +1967,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
}

for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]);
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
float lora_scale = std::get<1>(params.lora_adapter[i]);
int err = llama_model_apply_lora_from_file(model,
lora_adapter.c_str(),
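
For reference, a minimal sketch of calling the new helper directly from C++, assuming a build with LLAMA_USE_CURL (otherwise the stub above returns nullptr). The local path is illustrative; the repo/file pair is taken from the example in the code comment above:

    #include "common.h"
    #include "llama.h"
    #include <cstdio>

    int main() {
        llama_backend_init();

        struct llama_model_params mparams = llama_model_default_params();

        // Builds https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf,
        // downloads it to the given local path, then loads it.
        struct llama_model * model = llama_load_model_from_hf(
            "ggml-org/models",                    // repo
            "tinyllama-1.1b/ggml-model-f16.gguf", // file
            "models/ggml-model-f16.gguf",         // local path for the downloaded model
            mparams);

        if (model == NULL) {
            fprintf(stderr, "failed to load model from Hugging Face\n");
            return 1;
        }

        llama_free_model(model);
        llama_backend_free();
        return 0;
    }
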
common/common.h (10 changes: 6 additions & 4 deletions)
@@ -89,9 +89,11 @@ struct gpt_params {
struct llama_sampling_params sparams;

std::string model = "models/7B/ggml-model-f16.gguf"; // model path
std::string model_url = ""; // model url to download
std::string model_draft = ""; // draft model for speculative decoding
std::string model_draft = ""; // draft model for speculative decoding
std::string model_alias = "unknown"; // model alias
std::string model_url = ""; // model url to download
std::string hf_repo = ""; // HF repo
std::string hf_file = ""; // HF file
std::string prompt = "";
std::string prompt_file = ""; // store the external prompt file name
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
@@ -192,8 +194,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);

struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
struct llama_model_params params);
struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);

// Batch utils
