
Ability to use --hf-repo without --hf-file or -m #7504

Draft · wants to merge 1 commit into master
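With this change, --hf-repo can be used on its own: as the diff below shows, gpt_params_handle_model_default() now asks a new llama_get_hf_model_url() helper for a download URL whenever --hf-repo is set. If --hf-file is given, the URL points at that file as before; otherwise the helper queries the Hugging Face Hub tree API for the repository and picks a GGUF file, preferring a q4_k_m quantization, then any q4, then the first shard (a path containing "00001"), then any .gguf. The file name from the resolved URL also becomes the default --model path inside the cache directory.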
138 changes: 103 additions & 35 deletions common/common.cpp
@@ -193,18 +193,15 @@ int32_t cpu_get_num_math() {
 void gpt_params_handle_model_default(gpt_params & params) {
     if (!params.hf_repo.empty()) {
-        // short-hand to avoid specifying --hf-file -> default it to --model
-        if (params.hf_file.empty()) {
-            if (params.model.empty()) {
-                throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
-            }
-            params.hf_file = params.model;
-        } else if (params.model.empty()) {
+        params.model_url = llama_get_hf_model_url(params.hf_repo, params.hf_file);
+        if (params.model.empty()) {
             std::string cache_directory = fs_get_cache_directory();
             const bool success = fs_create_directory_with_parents(cache_directory);
             if (!success) {
                 throw std::runtime_error("failed to create cache directory: " + cache_directory);
             }
-            params.model = cache_directory + string_split(params.hf_file, '/').back();
+            // TODO: cache with params.hf_repo in directory
+            params.model = cache_directory + string_split(params.model_url, '/').back();
         }
     } else if (!params.model_url.empty()) {
         if (params.model.empty()) {
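To make the defaulting concrete, here is a small standalone sketch (not part of the PR) of the path derivation in the hunk above; the cache directory value is made up, and string_split(...).back() is re-implemented with plain std::string operations:

#include <iostream>
#include <string>

int main() {
    // resolved by llama_get_hf_model_url(); the example URL is taken from the comment removed further down
    const std::string model_url =
        "https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf";

    // stand-in for fs_get_cache_directory(); the real value is platform dependent
    const std::string cache_directory = "/home/user/.cache/llama.cpp/";

    // equivalent of string_split(model_url, '/').back(): keep everything after the last '/'
    const std::string file_name = model_url.substr(model_url.find_last_of('/') + 1);

    std::cout << cache_directory + file_name << std::endl;
    // prints: /home/user/.cache/llama.cpp/mixtral-8x7b-v0.1.Q4_K_M.gguf
    return 0;
}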
@@ -1888,9 +1885,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 
     llama_model * model = nullptr;
 
-    if (!params.hf_repo.empty() && !params.hf_file.empty()) {
-        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
-    } else if (!params.model_url.empty()) {
+    if (!params.model_url.empty()) {
         model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
     } else {
         model = llama_load_model_from_file(params.model.c_str(), mparams);
@@ -2061,6 +2056,16 @@ static bool starts_with(const std::string & str, const std::string & prefix) {
     return str.rfind(prefix, 0) == 0;
 }
 
+static bool ends_with(const std::string & str, const std::string & suffix) {
+    return str.length() >= suffix.length() && str.rfind(suffix) == str.length() - suffix.length();
+}
+
+static std::string tolower(std::string s) {
+    std::transform(s.begin(), s.end(), s.begin(),
+            [](unsigned char c){ return std::tolower(c); });
+    return s;
+}
+
 static bool llama_download_file(const std::string & url, const std::string & path) {
 
     // Initialize libcurl
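The two helpers added above are small enough to sanity-check in isolation; this standalone snippet copies them and exercises them against a hypothetical file name:

#include <algorithm>
#include <cassert>
#include <cctype>
#include <string>

static bool ends_with(const std::string & str, const std::string & suffix) {
    return str.length() >= suffix.length() && str.rfind(suffix) == str.length() - suffix.length();
}

static std::string tolower(std::string s) {
    std::transform(s.begin(), s.end(), s.begin(),
            [](unsigned char c){ return std::tolower(c); });
    return s;
}

int main() {
    const std::string path = "mixtral-8x7b-v0.1.Q4_K_M.gguf"; // hypothetical repository file
    assert(ends_with(path, ".gguf"));
    assert(!ends_with(path, ".bin"));
    assert(tolower(path).find("q4_k_m") != std::string::npos);
    return 0;
}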
@@ -2341,26 +2346,91 @@ struct llama_model * llama_load_model_from_url(
     return llama_load_model_from_file(path_model, params);
 }
 
-struct llama_model * llama_load_model_from_hf(
-        const char * repo,
-        const char * model,
-        const char * path_model,
-        const struct llama_model_params & params) {
-    // construct hugging face model url:
-    //
-    //  --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
-    //    https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
-    //
-    //  --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
-    //    https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
-    //
-
-    std::string model_url = "https://huggingface.co/";
-    model_url += repo;
-    model_url += "/resolve/main/";
-    model_url += model;
-
-    return llama_load_model_from_url(model_url.c_str(), path_model, params);
+static std::string llama_get_hf_model_url(
+        std::string & repo,
+        std::string & custom_file_path) {
+    std::stringstream ss;
+    json repo_files;
+
+    if (!custom_file_path.empty()) {
+        ss << "https://huggingface.co/" << repo << "/resolve/main/" << custom_file_path;
+        return ss.str();
+    }
+
+    {
+        // Initialize libcurl
+        std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
+
+        // Make the request to the Hub API
+        ss << "https://huggingface.co/api/models/" << repo << "/tree/main?recursive=true";
+        std::string url = ss.str();
+        std::string res_str;
+        curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
+        curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
+        typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
+        auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
+            static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
+            return size * nmemb;
+        };
+        curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
+        curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
+#if defined(_WIN32)
+        curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
+#endif
+        CURLcode res = curl_easy_perform(curl.get());
+
+        if (res != CURLE_OK) {
+            fprintf(stderr, "%s: cannot make GET request to Hugging Face Hub API\n", __func__);
+            return "";
+        }
+
+        long res_code;
+        curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
+        if (res_code != 200) {
+            fprintf(stderr, "%s: Hugging Face Hub API responded with status code %ld\n", __func__, res_code);
+            return "";
+        } else {
+            repo_files = json::parse(res_str);
+        }
+    }
+
+    if (!repo_files.is_array()) {
+        fprintf(stderr, "%s: response from Hugging Face Hub API is not an array\nRaw response:\n%s\n", __func__, repo_files.dump(4).c_str());
+        return "";
+    }
+
+    auto get_file_contains = [&](std::string piece) -> std::string {
+        for (auto elem : repo_files) {
+            std::string type = elem.at("type");
+            std::string path = elem.at("path");
+            if (
+                type == "file"
+                && ends_with(path, ".gguf")
+                && tolower(path).find(piece) != std::string::npos
+            ) return path;
+        }
+        return "";
+    };
+
+    std::string file_path = get_file_contains("q4_k_m");
+    if (file_path.empty()) {
+        file_path = get_file_contains("q4");
+    }
+    if (file_path.empty()) {
+        file_path = get_file_contains("00001");
+    }
+    if (file_path.empty()) {
+        file_path = get_file_contains("gguf");
+    }
+
+    if (file_path.empty()) {
+        fprintf(stderr, "%s: cannot find any GGUF file in the given repository\n", __func__);
+        return "";
+    }
+
+    ss = std::stringstream();
+    ss << "https://huggingface.co/" << repo << "/resolve/main/" << file_path;
+    return ss.str();
 }
 
 #else
@@ -2373,11 +2443,9 @@ struct llama_model * llama_load_model_from_url(
     return nullptr;
 }
 
-struct llama_model * llama_load_model_from_hf(
-        const char * /*repo*/,
-        const char * /*model*/,
-        const char * /*path_model*/,
-        const struct llama_model_params & /*params*/) {
+static std::string llama_get_hf_model_url(
+        std::string & /*repo*/,
+        std::string & /*custom_file_path*/) {
     fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
-    return nullptr;
+    return "";
 }
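For reference, a minimal sketch of the file-selection logic above run against a mocked Hub tree response. It assumes nlohmann::json is available (the library behind the json alias used in the diff), and the repository contents are invented for illustration; the real endpoint is https://huggingface.co/api/models/<repo>/tree/main?recursive=true, whose entries carry at least "type" and "path" fields, which is all the selection relies on:

#include <algorithm>
#include <cctype>
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

static bool ends_with(const std::string & str, const std::string & suffix) {
    return str.length() >= suffix.length() && str.rfind(suffix) == str.length() - suffix.length();
}

static std::string tolower(std::string s) {
    std::transform(s.begin(), s.end(), s.begin(),
            [](unsigned char c){ return std::tolower(c); });
    return s;
}

int main() {
    // invented stand-in for the /api/models/<repo>/tree/main?recursive=true response
    const json repo_files = json::parse(R"([
        {"type": "file",      "path": "README.md"},
        {"type": "directory", "path": "images"},
        {"type": "file",      "path": "model-00001-of-00002.gguf"},
        {"type": "file",      "path": "model.Q4_K_M.gguf"}
    ])");

    auto get_file_contains = [&](const std::string & piece) -> std::string {
        for (const auto & elem : repo_files) {
            const std::string type = elem.at("type");
            const std::string path = elem.at("path");
            if (type == "file" && ends_with(path, ".gguf") && tolower(path).find(piece) != std::string::npos) {
                return path;
            }
        }
        return "";
    };

    // same preference order as llama_get_hf_model_url()
    std::string file_path;
    for (const std::string piece : {"q4_k_m", "q4", "00001", "gguf"}) {
        file_path = get_file_contains(piece);
        if (!file_path.empty()) {
            break;
        }
    }

    std::cout << (file_path.empty() ? "<no gguf file>" : file_path) << std::endl; // prints: model.Q4_K_M.gguf
    return 0;
}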
2 changes: 1 addition & 1 deletion common/common.h
@@ -223,7 +223,7 @@ struct llama_model_params llama_model_params_from_gpt_params (const gpt_param
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
 struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
-struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
+static std::string llama_get_hf_model_url(std::string & repo, std::string & custom_file_path);
 
 // Batch utils

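Finally, a minimal sketch (not part of the PR) of what the shorthand looks like through the common.h API; it assumes a libcurl-enabled build, relies only on gpt_params_parse() and llama_init_from_gpt_params(), and reuses the repository name from the comment removed in common.cpp:

#include <tuple>
#include "common.h"
#include "llama.h"

int main() {
    // equivalent of invoking a tool with: --hf-repo TheBloke/Mixtral-8x7B-v0.1-GGUF (no --hf-file, no -m)
    char prog[] = "main";
    char arg1[] = "--hf-repo";
    char arg2[] = "TheBloke/Mixtral-8x7B-v0.1-GGUF";
    char * argv[] = { prog, arg1, arg2 };

    gpt_params params;
    if (!gpt_params_parse(3, argv, params)) {
        return 1;
    }

    llama_backend_init();

    llama_model   * model = nullptr;
    llama_context * ctx   = nullptr;
    std::tie(model, ctx) = llama_init_from_gpt_params(params);
    const bool ok = model != nullptr && ctx != nullptr;

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return ok ? 0 : 1;
}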