Update server.cpp example with correct startup sequence #6739

Draft: wants to merge 7 commits into master

265 changes: 144 additions & 121 deletions examples/server/server.cpp
@@ -2282,6 +2282,7 @@ struct server_context {
{"size", llama_model_size (model)},
};
}

};

static void server_print_usage(const char * argv0, const gpt_params & params, const server_params & sparams) {
@@ -3009,42 +3010,58 @@ int main(int argc, char ** argv) {
log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
}

// load the model
if (!ctx_server.load_model(params)) {
state.store(SERVER_STATE_ERROR);
return 1;
} else {
ctx_server.init();
state.store(SERVER_STATE_READY);
}
const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
switch (current_state) {
case SERVER_STATE_READY:
{
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
task.type = SERVER_TASK_TYPE_METRICS;
task.id_target = -1;

LOG_INFO("model loaded", {});
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);

const auto model_meta = ctx_server.model_meta();
// get the result
server_task_result result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);

// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
if (sparams.chat_template.empty()) {
if (!ctx_server.validate_model_chat_template()) {
LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
sparams.chat_template = "chatml";
}
}
const int n_idle_slots = result.data["idle"];
const int n_processing_slots = result.data["processing"];

// print sample chat example to make it clear which template is used
{
json chat;
chat.push_back({{"role", "system"}, {"content", "You are a helpful assistant"}});
chat.push_back({{"role", "user"}, {"content", "Hello"}});
chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
chat.push_back({{"role", "user"}, {"content", "How are you?"}});
json health = {
{"status", "ok"},
{"slots_idle", n_idle_slots},
{"slots_processing", n_processing_slots}
};

const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat);
res.status = 200; // HTTP OK
if (sparams.slots_endpoint && req.has_param("include_slots")) {
health["slots"] = result.data["slots"];
}

LOG_INFO("chat template", {
{"chat_example", chat_example},
{"built_in", sparams.chat_template.empty()},
});
}
if (n_idle_slots == 0) {
health["status"] = "no slot available";
if (req.has_param("fail_on_no_slot")) {
res.status = 503; // HTTP Service Unavailable
}
}

res.set_content(health.dump(), "application/json");
break;
}
case SERVER_STATE_LOADING_MODEL:
{
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
} break;
case SERVER_STATE_ERROR:
{
res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
} break;
}
};
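// Aside (illustrative sketch, not part of this diff): the handler above can be exercised
// with cpp-httplib's client; the host, port and fail_on_no_slot flag are assumptions.
httplib::Client health_cli("localhost", 8080);
if (auto health_res = health_cli.Get("/health?fail_on_no_slot=1")) {
    // expect 200 with {"status":"ok",...} or 503 with {"status":"no slot available",...}
    printf("%d %s\n", health_res->status, health_res->body.c_str());
}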

//
// Middlewares
@@ -3098,71 +3115,43 @@ int main(int argc, char ** argv) {
return false;
};

// register server middlewares
svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
if (!middleware_validate_api_key(req, res)) {
return httplib::Server::HandlerResponse::Handled;
}
return httplib::Server::HandlerResponse::Unhandled;
});
auto middleware_model_loading = [&sparams, &res_error](const httplib::Request & req, httplib::Response & res, server_state current_state) {

//
// Route handlers (or controllers)
//
// health check endpoints are exempt from this check so they can report the loading state
if (req.path == "/health" || req.path == "/v1/health") {
return true;
}

const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
switch (current_state) {
case SERVER_STATE_READY:
{
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
task.type = SERVER_TASK_TYPE_METRICS;
task.id_target = -1;

ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);

// get the result
server_task_result result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);

const int n_idle_slots = result.data["idle"];
const int n_processing_slots = result.data["processing"];

json health = {
{"status", "ok"},
{"slots_idle", n_idle_slots},
{"slots_processing", n_processing_slots}
};

res.status = 200; // HTTP OK
if (sparams.slots_endpoint && req.has_param("include_slots")) {
health["slots"] = result.data["slots"];
}

if (n_idle_slots == 0) {
health["status"] = "no slot available";
if (req.has_param("fail_on_no_slot")) {
res.status = 503; // HTTP Service Unavailable
}
}

res.set_content(health.dump(), "application/json");
break;
}
case SERVER_STATE_LOADING_MODEL:
{
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
return false;
} break;
case SERVER_STATE_ERROR:
{
res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
return false;
} break;
default:
break;
}
return true;
};

// register server middlewares
svr->set_pre_routing_handler([&middleware_validate_api_key, &state, &middleware_model_loading](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
if (!middleware_model_loading(req, res, current_state) || !middleware_validate_api_key(req, res)) {
return httplib::Server::HandlerResponse::Handled;
}
return httplib::Server::HandlerResponse::Unhandled;
});
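// Request flow implied by the two middlewares above (summary for reference):
// - while the model is loading, /health and /v1/health pass through and report "Loading model";
//   every other route is rejected by middleware_model_loading with an unavailable error
// - once the state is SERVER_STATE_READY, only middleware_validate_api_key runs before the handlers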

//
// Route handlers (or controllers)
//

const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
if (!sparams.slots_endpoint) {
res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
@@ -3482,25 +3471,6 @@ int main(int argc, char ** argv) {
}
};

const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

json models = {
{"object", "list"},
{"data", {
{
{"id", params.model_alias},
{"object", "model"},
{"created", std::time(0)},
{"owned_by", "llamacpp"},
{"meta", model_meta}
},
}}
};

res.set_content(models.dump(), "application/json; charset=utf-8");
};

const auto handle_chat_completions = [&ctx_server, &sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
@@ -3714,33 +3684,31 @@ int main(int argc, char ** argv) {
: responses[0];
return res.set_content(root.dump(), "application/json; charset=utf-8");
};

const auto handle_models = [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
return false;
};
};

//
// Router
//
json model_meta = ctx_server.model_meta();

// register static assets routes
if (!sparams.public_path.empty()) {
// Set the base directory for serving static files
svr->set_base_dir(sparams.public_path);
}
json models = {
{"object", "list"},
{"data", {
{
{"id", params.model_alias},
{"object", "model"},
{"created", std::time(0)},
{"owned_by", "llamacpp"},
{"meta", model_meta}
},
}}
};

// using embedded static files
svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
svr->Get("/json-schema-to-grammar.mjs", handle_static_file(
json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
res.set_content(models.dump(), "application/json; charset=utf-8");
};

// register API routes
svr->Get ("/health", handle_health);
svr->Get ("/v1/health", handle_health);
svr->Get ("/slots", handle_slots);
svr->Get ("/metrics", handle_metrics);
svr->Get ("/props", handle_props);
@@ -3756,11 +3724,31 @@ int main(int argc, char ** argv) {
svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize);

if (!sparams.slot_save_path.empty()) {
// only enable slot endpoints if slot_save_path is set
svr->Post("/slots/:id_slot", handle_slots_action);
}

auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
return false;
};
};

if (!sparams.public_path.empty()) {
// Set the base directory for serving static files
svr->set_base_dir(sparams.public_path);
}

// using embedded static files
svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
svr->Get("/json-schema-to-grammar.mjs", handle_static_file(
json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));

//
// Start the server
//
@@ -3782,6 +3770,41 @@ int main(int argc, char ** argv) {

return 0;
});

// load the model
if (!ctx_server.load_model(params)) {
state.store(SERVER_STATE_ERROR);
return 1;
} else {
ctx_server.init();
state.store(SERVER_STATE_READY);
}

LOG_INFO("model loaded", {});

// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
if (sparams.chat_template.empty()) {
if (!ctx_server.validate_model_chat_template()) {
LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
sparams.chat_template = "chatml";
}
}

// print sample chat example to make it clear which template is used
{
json chat;
chat.push_back({{"role", "system"}, {"content", "You are a helpful assistant"}});
chat.push_back({{"role", "user"}, {"content", "Hello"}});
chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
chat.push_back({{"role", "user"}, {"content", "How are you?"}});

const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat);

LOG_INFO("chat template", {
{"chat_example", chat_example},
{"built_in", sparams.chat_template.empty()},
});
}
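// With the chatml fallback, the logged chat_example should look roughly like this
// (standard chatml markers; exact whitespace depends on the template implementation):
//   <|im_start|>system
//   You are a helpful assistant<|im_end|>
//   <|im_start|>user
//   Hello<|im_end|>
//   <|im_start|>assistant
//   Hi there<|im_end|>
//   <|im_start|>user
//   How are you?<|im_end|>
//   <|im_start|>assistant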

ctx_server.queue_tasks.on_new_task(std::bind(
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
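Taken together, the reordering in this diff amounts to: register routes and middlewares first, start serving, and only then load the model, so /health can answer while the model is still loading. A minimal standalone sketch of that pattern, independent of server.cpp (the port, the sleep standing in for model loading, and the single background thread are assumptions):

#include "httplib.h"
#include <atomic>
#include <chrono>
#include <thread>

enum server_state { SERVER_STATE_LOADING_MODEL, SERVER_STATE_READY };

int main() {
    std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
    httplib::Server svr;

    // routes are registered before the expensive initialization
    svr.Get("/health", [&](const httplib::Request &, httplib::Response & res) {
        if (state.load() == SERVER_STATE_READY) {
            res.set_content("{\"status\":\"ok\"}", "application/json");
        } else {
            res.status = 503;
            res.set_content("{\"status\":\"loading model\"}", "application/json");
        }
    });

    // serve requests on a background thread so health checks work during loading
    std::thread http_thread([&]() { svr.listen("127.0.0.1", 8080); });

    // stand-in for ctx_server.load_model(params)
    std::this_thread::sleep_for(std::chrono::seconds(5));
    state.store(SERVER_STATE_READY);

    http_thread.join();
    return 0;
}

Keeping the HTTP listener on its own thread means readiness probes and load balancers get an immediate, honest 503 during startup instead of a refused connection.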