
When llama_chat_apply_template doesn't work #11687

Open · wants to merge 3 commits into base: master
61 changes: 37 additions & 24 deletions examples/run/run.cpp
@@ -837,37 +837,50 @@ static void add_message(const char * role, const std::string & text, LlamaData &
     llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() });
 }
 
+// Function to handle Jinja template application
+static int handle_jinja_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append) {
+    json messages = json::array();
+    for (const auto & msg : llama_data.messages) {
+        messages.push_back({
+            { "role", msg.role },
+            { "content", msg.content },
+        });
+    }
+
+    try {
+        minja::chat_template_inputs tmpl_inputs;
+        tmpl_inputs.messages = messages;
+        tmpl_inputs.add_generation_prompt = append;
+
+        minja::chat_template_options tmpl_opts;
+        tmpl_opts.use_bos_token = false;
+        tmpl_opts.use_eos_token = false;
+
+        auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+        llama_data.fmtted.resize(result.size() + 1);
+        memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+        return result.size();
+    } catch (const std::exception & e) {
+        printe("failed to render the chat template: %s\n", e.what());
+    }
+
+    return -1;
+}
+
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
     if (use_jinja) {
-        json messages = json::array();
-        for (const auto & msg : llama_data.messages) {
-            messages.push_back({
-                {"role", msg.role},
-                {"content", msg.content},
-            });
-        }
-        try {
-            minja::chat_template_inputs tmpl_inputs;
-            tmpl_inputs.messages = messages;
-            tmpl_inputs.add_generation_prompt = append;
-
-            minja::chat_template_options tmpl_opts;
-            tmpl_opts.use_bos_token = false;
-            tmpl_opts.use_eos_token = false;
-
-            auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
-            llama_data.fmtted.resize(result.size() + 1);
-            memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
-            return result.size();
-        } catch (const std::exception & e) {
-            printe("failed to render the chat template: %s\n", e.what());
-            return -1;
-        }
+        return handle_jinja_template(tmpl, llama_data, append);
     }
 
     int result = llama_chat_apply_template(
         tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
+    // If llama_chat_apply_template fails to apply template, fallback to using jinja
+    if (result < 0) {
+        return handle_jinja_template(tmpl, llama_data, append);
+    }
 
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
         result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
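
The tail of the hunk shows the usual size-probe pattern for llama_chat_apply_template: the call returns the number of bytes the formatted prompt needs, so a negative result now triggers the jinja fallback while a result larger than the current buffer triggers a resize and retry. A minimal standalone sketch of that pattern follows; the helper name and buffer handling are illustrative, not code from this PR, only the llama_chat_apply_template calls mirror the diff:

#include "llama.h"

#include <cstdint>
#include <string>
#include <vector>

// Illustrative helper, not part of this PR: format `msgs` with the model's template source,
// growing `buf` as needed; a negative return means the caller should fall back to jinja.
static int32_t format_with_builtin_template(const std::string & tmpl_src,
                                            const std::vector<llama_chat_message> & msgs,
                                            std::vector<char> & buf) {
    int32_t res = llama_chat_apply_template(tmpl_src.c_str(), msgs.data(), msgs.size(),
                                            /* add_ass */ true, buf.data(), (int32_t) buf.size());
    if (res < 0) {
        return -1; // unsupported template: fall back to the jinja path, as run.cpp now does
    }
    if (res > (int32_t) buf.size()) {
        buf.resize(res); // the first call reported the required size, retry with a large enough buffer
        res = llama_chat_apply_template(tmpl_src.c_str(), msgs.data(), msgs.size(),
                                        /* add_ass */ true, buf.data(), (int32_t) buf.size());
    }
    return res;
}
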
52 changes: 33 additions & 19 deletions examples/server/server.cpp
@@ -1895,30 +1895,44 @@ struct server_context {
         return true;
     }
 
+    bool validate_jinja_templates() const {
+        auto templates = common_chat_templates_from_model(model, "");
+        common_chat_inputs inputs;
+        inputs.messages = json::array({
+            {
+                { "role", "user" },
+                { "content", "test" },
+            }
+        });
+        GGML_ASSERT(templates.template_default);
+        try {
+            common_chat_params_init(*templates.template_default, inputs);
+            if (templates.template_tool_use) {
+                common_chat_params_init(*templates.template_tool_use, inputs);
+            }
+
+            return true;
+        } catch (const std::exception & e) {
+            SRV_ERR("failed to apply template: %s\n", e.what());
+
+            return false;
+        }
+    }
+
     bool validate_builtin_chat_template(bool use_jinja) const {
Collaborator:

If we choose to fall back to jinja, logically we should remove the bool use_jinja from this function's signature, but I'm not sure.

Also have a look at where this function is used:

        if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
            LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
            chat_templates = common_chat_templates_from_model(model, "chatml");
        } else {
            chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
        }
        GGML_ASSERT(chat_templates.template_default.get() != nullptr);

So the LOG_WRN message is no longer valid and probably needs to be changed too.
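
For illustration: with the fallback in place, validate_builtin_chat_template only returns false when both the built-in and the jinja engine fail, so the warning would only fire in that case. One possible rewording of the caller (a sketch only, not part of this PR; the log text is a suggestion):

        if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
            // reached only when neither the built-in nor the jinja engine could apply the model's template
            LOG_WRN("%s: The chat template that comes with this model could not be applied by the built-in or the jinja engine, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
            chat_templates = common_chat_templates_from_model(model, "chatml");
        } else {
            chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
        }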

Collaborator (author):

Well, the bool still kind of makes sense; somebody might want to force the use of jinja. If llama_chat_apply_template works, jinja won't be used at all.

We could delete the validation and just fall back on failure. WDYT?

Collaborator:

In this case, I think we can change the variable name to bool prefer_jinja to make it more intuitive.

Contributor:

Yes, a prefer_ or enforce_ prefix would be good.
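
If the rename goes ahead, only the parameter name and its uses change. A minimal sketch of how the validator could read with the proposed name, assuming the fallback behaviour from this PR is kept (illustrative only, not merged code):

    // sketch: "prefer_jinja" is the name proposed in this thread, not what the PR currently uses
    bool validate_builtin_chat_template(bool prefer_jinja) const {
        llama_chat_message chat[] = {
            { "user", "test" }
        };

        if (prefer_jinja) {
            return validate_jinja_templates();
        }

        const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
        const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
        if (chat_res < 0) {
            return validate_jinja_templates(); // the legacy engine failed, try the jinja engine
        }

        return chat_res > 0;
    }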

-        llama_chat_message chat[] = {{"user", "test"}};
+        llama_chat_message chat[] = {
+            { "user", "test" }
+        };
 
         if (use_jinja) {
-            auto templates = common_chat_templates_from_model(model, "");
-            common_chat_inputs inputs;
-            inputs.messages = json::array({{
-                {"role", "user"},
-                {"content", "test"},
-            }});
-            GGML_ASSERT(templates.template_default);
-            try {
-                common_chat_params_init(*templates.template_default, inputs);
-                if (templates.template_tool_use) {
-                    common_chat_params_init(*templates.template_tool_use, inputs);
-                }
-                return true;
-            } catch (const std::exception & e) {
-                SRV_ERR("failed to apply template: %s\n", e.what());
-                return false;
-            }
+            return validate_jinja_templates();
         } else {
             const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
             const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
+            if (chat_res < 0) {
+                return validate_jinja_templates();
+            }
 
             return chat_res > 0;
         }
     }