Skip to content

chat: improve llama 3.x handling of <|python_tag|> (+ allow --special combo) #13932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions common/chat-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::
result_.tool_calls.emplace_back(tool_call);
return true;
}
// Adds a tool call described by a JSON object.
//
// Reads the "name" and "id" fields of `tool_call`, plus the field holding the
// call arguments, and forwards the three strings to the
// add_tool_call(name, id, arguments) overload, returning whatever that
// overload returns. A field that is absent from `tool_call` is passed on as an
// empty string.
//
// `arguments_name` selects which key holds the arguments, for formats that do
// not call it "arguments" (the Llama 3.x parser in this change passes
// "parameters").
//
// NOTE(review): this span is a rendered diff, so each changed line appears
// twice — the first of each near-identical pair is the removed version, the
// second is its replacement.
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
bool common_chat_msg_parser::add_tool_call(const json & tool_call, const char * arguments_name) {
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
std::string arguments = tool_call.contains(arguments_name) ? tool_call.at(arguments_name) : "";
return add_tool_call(name, id, arguments);
}

Expand Down
2 changes: 1 addition & 1 deletion common/chat-parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class common_chat_msg_parser {
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);

// Adds a tool call using the "name" and "id" fields of the json object, with the arguments read from the field named by arguments_name ("arguments" by default)
bool add_tool_call(const nlohmann::ordered_json & tool_call);
bool add_tool_call(const nlohmann::ordered_json & tool_call, const char * arguments_name = "arguments");

// Adds an array of tool calls using their "name", "id" and "arguments" fields.
bool add_tool_calls(const nlohmann::ordered_json & arr);
Expand Down
50 changes: 30 additions & 20 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"<|python_tag|>\"? space "
"\"{\" space "
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
Expand All @@ -1105,12 +1106,12 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
"(\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\")[\\s\\S]*", // + name + "\"[\\s\\S]*",
"((?:<\\|python_tag\\|>\\s*)?\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\")[\\s\\S]*", // + name + "\"[\\s\\S]*",
});
if (!builtin_tools.empty()) {
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
data.preserved_tokens.push_back("<|python_tag|>");
}
data.preserved_tokens.push_back("<|python_tag|>");
// Allow a few empty lines on top of the usual constrained json schema space rule.
builder.add_rule("root", string_join(tool_rules, " | "));
data.additional_stops.push_back("<|eom_id|>");
Expand All @@ -1134,16 +1135,18 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
return;
}

static const common_regex function_regex(
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
static const common_regex close_regex("\\}\\s*");
static const common_regex python_tag_regex("\\s*<\\|python_tag\\|>");

static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\(");
static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*");
auto initial_pos = builder.pos();
if (auto res = builder.try_consume_regex(python_tag_regex)) {
if (auto tc = builder.try_consume_json_with_dumped_args({{"parameters"}})) {
if (!builder.add_tool_call(tc->value, "parameters")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
} else if (with_builtin_tools) {
static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\(");
static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*");

if (with_builtin_tools) {
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
if (auto res = builder.try_find_regex(builtin_call_regex)) {
auto fun_res = builder.consume_regex(function_name_regex);
auto function_name = builder.str(fun_res.groups[1]);

Expand Down Expand Up @@ -1171,17 +1174,24 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
if (!builder.add_tool_call(function_name, "", arguments)) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
return;
}
} else if (auto tc = builder.try_consume_json_with_dumped_args({{"parameters"}})) {
if (!builder.add_tool_call(tc->value, "parameters")) {
auto has_unknown_keys = false;
for (const auto & [key, value] : tc->value.items()) {
if (key != "parameters" && key != "name" && key != "type") {
has_unknown_keys = true;
break;
}
}
if (has_unknown_keys) {
builder.move_to(initial_pos);
} else {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
}
parse_json_tool_calls(
builder,
/* block_open= */ std::nullopt,
/* function_regex_start_only= */ function_regex,
/* function_regex= */ std::nullopt,
close_regex,
std::nullopt);

builder.add_content(builder.consume_rest());
}

static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
Expand Down Expand Up @@ -1453,8 +1463,8 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
if (has_raw_python) {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
data.preserved_tokens.push_back("<|python_tag|>");
}
data.preserved_tokens.push_back("<|python_tag|>");
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
Expand Down
30 changes: 30 additions & 0 deletions tests/test-chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,36 @@ static void test_template_output_parsers() {
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_LLAMA_3_X}));
assert_equals(
message_assist_call,
common_chat_parse(
"<|python_tag|>{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_LLAMA_3_X}));
assert_equals(
message_assist_call,
common_chat_parse(
"<|python_tag|>{\"type\": \"function\", \"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_LLAMA_3_X}));
assert_equals(
simple_assist_msg("{\"something\": \"else\"}"),
common_chat_parse(
"{\"something\": \"else\"}",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_LLAMA_3_X}));
assert_equals(
message_assist_empty,
common_chat_parse(
"{\"some",
/* is_partial= */ true,
{COMMON_CHAT_FORMAT_LLAMA_3_X}));
assert_equals(
message_assist_empty,
common_chat_parse(
"{\"parameters\": {\"arg1\": 1}",
/* is_partial= */ true,
{COMMON_CHAT_FORMAT_LLAMA_3_X}));

// test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools,
Expand Down
Loading