
Commit 0e6fa4a

sampling : refactor + optimize penalties sampler
ggml-ci

1 parent 8faa1d4 commit 0e6fa4a

14 files changed: +47 -140 lines

common/arg.cpp (-7)

@@ -826,13 +826,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.ignore_eos = true;
         }
     ).set_sparam());
-    add_opt(common_arg(
-        {"--penalize-nl"},
-        string_format("penalize newline tokens (default: %s)", params.sampling.penalize_nl ? "true" : "false"),
-        [](common_params & params) {
-            params.sampling.penalize_nl = true;
-        }
-    ).set_sparam());
     add_opt(common_arg(
         {"--temp"}, "N",
         string_format("temperature (default: %.1f)", (double)params.sampling.temp),

common/common.h (+1, -1)

@@ -95,6 +95,7 @@ enum common_sampler_type {
     COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
     COMMON_SAMPLER_TYPE_XTC         = 8,
     COMMON_SAMPLER_TYPE_INFILL      = 9,
+    COMMON_SAMPLER_TYPE_PENALTIES   = 10,
 };

 // dimensionality reduction methods, used by cvector-generator
@@ -130,7 +131,6 @@ struct common_params_sampling {
     int32_t mirostat     = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau = 5.00f; // target entropy
     float   mirostat_eta = 0.10f; // learning rate
-    bool    penalize_nl  = false; // consider newlines as a repeatable token
     bool    ignore_eos   = false;
     bool    no_perf      = false; // disable performance metrics
     bool    timing_per_token = false;

common/sampling.cpp (+7, -12)

@@ -161,18 +161,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 params.logit_bias.size(),
                 params.logit_bias.data()));

-    llama_sampler_chain_add(result->chain,
-            llama_sampler_init_penalties(
-                llama_n_vocab  (model),
-                llama_token_eos(model),
-                llama_token_nl (model),
-                params.penalty_last_n,
-                params.penalty_repeat,
-                params.penalty_freq,
-                params.penalty_present,
-                params.penalize_nl,
-                params.ignore_eos));
-
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
@@ -208,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 case COMMON_SAMPLER_TYPE_INFILL:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
                     break;
+                case COMMON_SAMPLER_TYPE_PENALTIES:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                    break;
                 default:
                     GGML_ASSERT(false && "unknown sampler type");
             }
@@ -415,6 +406,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
         case COMMON_SAMPLER_TYPE_XTC:         return 'x';
         case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
+        case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
         default : return '?';
     }
 }
@@ -429,6 +421,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
         case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
         case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
+        case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
         default : return "";
     }
 }
@@ -443,6 +436,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
         { "xtc",         COMMON_SAMPLER_TYPE_XTC },
         { "infill",      COMMON_SAMPLER_TYPE_INFILL },
+        { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
     };

     // since samplers names are written multiple ways
@@ -489,6 +483,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
     };

     std::vector<common_sampler_type> samplers;
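
An illustrative sketch (not part of this commit's diff) of what the refactor means on the caller side: the penalties step is now selected and ordered through the samplers list like any other sampler, using the new COMMON_SAMPLER_TYPE_PENALTIES value and the trimmed-down initializer, instead of always being added at the front of the chain. The parameter values and the exact ordering below are assumptions, not project defaults.

#include "common.h"
#include "sampling.h"

// Build a common_sampler where the repetition penalties run as an ordinary
// chain step, after top-k has narrowed the candidate set (see the NOTE added
// to include/llama.h in this commit).
static struct common_sampler * make_sampler(const struct llama_model * model) {
    common_params_sampling sparams;

    sparams.penalty_last_n = 64;   // penalize repeats within the last 64 tokens (assumed value)
    sparams.penalty_repeat = 1.1f; // 1.0 = disabled

    sparams.samplers = {
        COMMON_SAMPLER_TYPE_TOP_K,      // narrow the candidates first
        COMMON_SAMPLER_TYPE_PENALTIES,  // then apply the repetition penalties
        COMMON_SAMPLER_TYPE_TEMPERATURE,
    };

    return common_sampler_init(model, sparams);
}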

examples/batched/batched.cpp (+7)

@@ -65,10 +65,17 @@ int main(int argc, char ** argv) {
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);

     auto sparams = llama_sampler_chain_default_params();
+    sparams.no_perf = false;

     llama_sampler * smpl = llama_sampler_chain_init(sparams);

     llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
+    llama_sampler_chain_add(smpl,
+        llama_sampler_init_penalties(
+            params.sampling.penalty_last_n,
+            params.sampling.penalty_repeat,
+            params.sampling.penalty_freq,
+            params.sampling.penalty_present));
     llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
     llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
     llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));

examples/main/README.md (+1, -6)

@@ -163,7 +163,7 @@ If the pause is undesirable, a value of -2 will stop generation immediately when

 The `--no-context-shift` option allows you to stop the infinite text generation once the finite context window is full.

-It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter.
+It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value.

 ### Temperature

@@ -177,16 +177,11 @@ Example usage: `--temp 0`

 - `--repeat-penalty N`: Control the repetition of token sequences in the generated text default: 1.0, 1.0 = disabled).
 - `--repeat-last-n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size).
-- `--no-penalize-nl`: Disable penalization for newline tokens when applying the repeat penalty.

 The `repeat-penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1.

 The `repeat-last-n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx-size`).

-Use the `--no-penalize-nl` option to disable newline penalization when applying the repeat penalty. This option is particularly useful for generating chat conversations, dialogues, code, poetry, or any text where newline tokens play a significant role in structure and formatting. Disabling newline penalization helps maintain the natural flow and intended formatting in these specific use cases.
-
-Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl`
-
 ### DRY Repetition Penalty

 DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)).
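
(Illustrative aside: with the newline flag gone, the penalty options documented above are simply combined on their own, e.g. `--repeat-penalty 1.15 --repeat-last-n 128`; the values here are examples only.)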

examples/server/README.md (-5)

@@ -104,7 +104,6 @@ The project is under active development, and we are [looking for feedback and co
 | `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
 | `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: dkypmxt) |
 | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
-| `--penalize-nl` | penalize newline tokens (default: false) |
 | `--temp N` | temperature (default: 0.8) |
 | `--top-k N` | top-k sampling (default: 40, 0 = disabled) |
 | `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) |
@@ -392,8 +391,6 @@ These words will not be included in the completion, so make sure to add them to

 `repeat_last_n`: Last n tokens to consider for penalizing repetition. Default: `64`, where `0` is disabled and `-1` is ctx-size.

-`penalize_nl`: Penalize newline tokens when applying the repeat penalty. Default: `true`
-
 `presence_penalty`: Repeat alpha presence penalty. Default: `0.0`, which is disabled.

 `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
@@ -654,7 +651,6 @@ This endpoint is public (no API key check). By default, it is read-only. To make
     "mirostat": 0,
     "mirostat_tau": 5.0,
     "mirostat_eta": 0.10000000149011612,
-    "penalize_nl": false,
     "stop": [],
     "max_tokens": -1,
     "n_keep": 0,
@@ -844,7 +840,6 @@ Example:
     "mirostat": 0,
     "mirostat_tau": 5.0,
     "mirostat_eta": 0.10000000149011612,
-    "penalize_nl": false,
     "stop": [],
     "max_tokens": -1,
     "n_keep": 0,

examples/server/public_legacy/index-new.html (-1)

@@ -39,7 +39,6 @@
       temperature: 0.8, // adapt all following parameters to optimized min-p requierements. If for non-english, set to 0.6 or lower
       repeat_last_n: 0, // 0 = disable penalty, -1 = context size
       repeat_penalty: 1.0, // 1.0 = disabled
-      penalize_nl: false, // true only useful for infinite completion
       dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
       dry_base: 1.75, // 0.0 = disabled
       dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well

examples/server/public_legacy/index.html (-2)

@@ -303,7 +303,6 @@
       temperature: 0.7,
       repeat_last_n: 256, // 0 = disable penalty, -1 = context size
       repeat_penalty: 1.18, // 1.0 = disabled
-      penalize_nl: false,
       dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
       dry_base: 1.75, // 0.0 = disabled
       dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
@@ -1006,7 +1005,6 @@
             ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
             ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
             ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
-            ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
             ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
             ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
             ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}

examples/server/server.cpp (-2)

@@ -135,7 +135,6 @@ struct slot_params {
             {"mirostat",          sampling.mirostat},
             {"mirostat_tau",      sampling.mirostat_tau},
             {"mirostat_eta",      sampling.mirostat_eta},
-            {"penalize_nl",       sampling.penalize_nl},
             {"stop",              antiprompt},
             {"max_tokens",        n_predict}, // User configured n_predict
             {"n_keep",            n_keep},
@@ -226,7 +225,6 @@ struct server_task {
         params.sampling.mirostat      = json_value(data, "mirostat",      defaults.sampling.mirostat);
         params.sampling.mirostat_tau  = json_value(data, "mirostat_tau",  defaults.sampling.mirostat_tau);
         params.sampling.mirostat_eta  = json_value(data, "mirostat_eta",  defaults.sampling.mirostat_eta);
-        params.sampling.penalize_nl   = json_value(data, "penalize_nl",   defaults.sampling.penalize_nl);
         params.sampling.seed          = json_value(data, "seed",          defaults.sampling.seed);
         params.sampling.n_probs       = json_value(data, "n_probs",       defaults.sampling.n_probs);
         params.sampling.min_keep      = json_value(data, "min_keep",      defaults.sampling.min_keep);

examples/server/themes/buttons-top/index.html (-2)

@@ -222,7 +222,6 @@
       temperature: 0.7,
       repeat_last_n: 256, // 0 = disable penalty, -1 = context size
       repeat_penalty: 1.18, // 1.0 = disabled
-      penalize_nl: false,
       top_k: 40, // <= 0 to use vocab size
       top_p: 0.95, // 1.0 = disabled
       min_p: 0.05, // 0 = disabled
@@ -779,7 +778,6 @@
             ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
             ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
             ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
-            ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
             ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
             ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
             ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}

examples/server/themes/wild/index.html (-2)

@@ -225,7 +225,6 @@
       temperature: 0.7,
       repeat_last_n: 256, // 0 = disable penalty, -1 = context size
       repeat_penalty: 1.18, // 1.0 = disabled
-      penalize_nl: false,
       top_k: 40, // <= 0 to use vocab size
       top_p: 0.95, // 1.0 = disabled
       min_p: 0.05, // 0 = disabled
@@ -782,7 +781,6 @@
             ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
             ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
             ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
-            ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
             ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
             ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
             ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}

include/llama.h (+5, -9)

@@ -1137,16 +1137,12 @@ extern "C" {
             const char * grammar_str,
             const char * grammar_root);

+    /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
     LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
-                         int32_t   n_vocab,          // llama_n_vocab()
-                     llama_token   special_eos_id,   // llama_token_eos()
-                     llama_token   linefeed_id,      // llama_token_nl()
-                         int32_t   penalty_last_n,   // last n tokens to penalize (0 = disable penalty, -1 = context size)
-                           float   penalty_repeat,   // 1.0 = disabled
-                           float   penalty_freq,     // 0.0 = disabled
-                           float   penalty_present,  // 0.0 = disabled
-                            bool   penalize_nl,      // consider newlines as a repeatable token
-                            bool   ignore_eos);      // ignore the end-of-sequence token
+                         int32_t   penalty_last_n,   // last n tokens to penalize (0 = disable penalty, -1 = context size)
+                           float   penalty_repeat,   // 1.0 = disabled
+                           float   penalty_freq,     // 0.0 = disabled
+                           float   penalty_present); // 0.0 = disabled

     /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
     LLAMA_API struct llama_sampler * llama_sampler_init_dry(
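
For reference, a minimal sketch of calling the new four-argument initializer directly through the llama.h API (values are illustrative, not defaults; the top-k step is placed first per the NOTE above so the repeated-token search does not run over the full vocabulary):

#include "llama.h"

// Assemble a small sampler chain: top-k -> penalties -> final distribution sampling.
static struct llama_sampler * make_chain(void) {
    struct llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(smpl, llama_sampler_init_penalties(
            /*penalty_last_n =*/ 64,     // 0 = disable penalty, -1 = context size
            /*penalty_repeat =*/ 1.1f,   // 1.0 = disabled
            /*penalty_freq   =*/ 0.0f,   // 0.0 = disabled
            /*penalty_present=*/ 0.0f)); // 0.0 = disabled
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    return smpl;
}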
