
Commit 964f497

Merge pull request #125 from martindevans/native_sampling_api_improvements
Removed unnecessary parameters from some low level sampler methods
2 parents 0ce5cf9 + 826c6aa commit 964f497

File tree

5 files changed: +107, -24 lines

LLama/Extensions/DictionaryExtensions.cs

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+using System.Collections.Generic;
+
+namespace LLama.Extensions
+{
+    internal static class DictionaryExtensions
+    {
+#if NETSTANDARD2_0
+        public static TValue GetValueOrDefault<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, TValue defaultValue)
+        {
+            return dictionary.TryGetValue(key, out var value) ? value : defaultValue;
+        }
+#endif
+    }
+}
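Note: this is a netstandard2.0 polyfill for the GetValueOrDefault overload that newer target frameworks already ship in the BCL. A minimal usage sketch follows; the helper class and the int/float dictionary are illustrative and not part of the commit, and since DictionaryExtensions is internal the polyfill only applies inside the LLama project.

using System.Collections.Generic;

internal static class LogitBiasLookupExample
{
    // Returns the configured bias for a token, or 0.0f when none is set.
    // On netstandard2.0 (inside the LLama assembly) this binds to the polyfill above;
    // on newer frameworks it binds to the built-in CollectionExtensions overload.
    public static float BiasFor(IReadOnlyDictionary<int, float> logitBias, int tokenId)
    {
        return logitBias.GetValueOrDefault(tokenId, 0.0f);
    }
}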
LLama/Extensions/IEnumerableExtensions.cs

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+using System.Collections.Generic;
+using System.Linq;
+
+namespace LLama.Extensions
+{
+    internal static class IEnumerableExtensions
+    {
+#if NETSTANDARD2_0
+        public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> source, int count)
+        {
+            var list = source.ToList();
+
+            if (count >= list.Count)
+                return list;
+
+            list.RemoveRange(0, list.Count - count);
+            return list;
+        }
+#endif
+    }
+}
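Note: TakeLast is a netstandard2.0 polyfill for the System.Linq method available on newer frameworks; the reworked ApplyPenalty below uses it to trim lastTokens down to the repetition window. A small sketch with illustrative values (the helper name is hypothetical); inside the library on netstandard2.0 the same call shape binds to the polyfill above instead of System.Linq.

using System.Collections.Generic;
using System.Linq;

internal static class TakeLastExample
{
    // Keeps only the trailing `count` items, e.g. LastN(new[] { 1, 2, 3, 4 }, 2) -> { 3, 4 }.
    public static int[] LastN(IEnumerable<int> tokens, int count)
    {
        return tokens.TakeLast(count).ToArray();
    }
}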

LLama/LLamaContext.cs

Lines changed: 25 additions & 20 deletions

@@ -355,36 +355,41 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dic
     int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
     bool penalizeNL = true)
 {
-    var n_vocab = _ctx.VocabCount;
     var logits = _ctx.GetLogits();

     // Apply params.logit_bias map
-    if(logitBias is not null)
+    if (logitBias is not null)
     {
         foreach (var (key, value) in logitBias)
-        {
             logits[key] += value;
-        }
     }

-    var candidates = new LLamaTokenData[n_vocab];
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++)
-        candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f);
-    LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates);
-
-    // Apply penalties
-    float nl_logit = logits[NativeApi.llama_token_nl()];
-    int lastTokensCount = lastTokens.Count();
-    var last_n_repeat = Math.Min(Math.Min(lastTokensCount, repeatLastTokensCount), ContextSize);
-    SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p,
-        lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(),
-        (ulong)last_n_repeat, repeatPenalty);
-    SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p,
-        lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(),
-        (ulong)last_n_repeat, alphaFrequency, alphaPresence);
+    // Save the newline logit value
+    var nl_token = NativeApi.llama_token_nl();
+    var nl_logit = logits[nl_token];
+
+    // Convert logits into token candidates
+    var candidates_p = LLamaTokenDataArray.Create(logits);
+
+    // Extract most recently returned tokens
+    var last_n_repeat = Math.Min(ContextSize, repeatLastTokensCount);
+    var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray();
+
+    // Apply penalties to candidates
+    SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, last_n_array, repeatPenalty);
+    SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, last_n_array, alphaFrequency, alphaPresence);
+
+    // Restore newline token logit value if necessary
     if (!penalizeNL)
     {
-        logits[NativeApi.llama_token_nl()] = nl_logit;
+        var candidatesSpan = candidates_p.data.Span;
+        for (var i = 0; i < candidates_p.data.Length; i++)
+        {
+            ref var item = ref candidatesSpan[i];
+            if (item.id == nl_token)
+                item.logit = nl_logit;
+        }
+        candidates_p.sorted = false;
    }

    return candidates_p;
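Note: the reworked ApplyPenalty builds the candidate list with LLamaTokenDataArray.Create and trims the token history with TakeLast, then hands plain arrays to the new sampler overloads. A hedged usage sketch follows; the wrapper class, argument values, and the assumption that LLamaContext lives in the LLama namespace are illustrative, not part of the commit.

using System.Collections.Generic;
using LLama;
using LLama.Native;

internal static class ApplyPenaltyExample
{
    // `context` must be an already-loaded LLamaContext; the arguments simply mirror the defaults above.
    public static LLamaTokenDataArray BuildCandidates(LLamaContext context, IEnumerable<int> lastTokens)
    {
        return context.ApplyPenalty(
            lastTokens,
            logitBias: null,
            repeatLastTokensCount: 64,
            repeatPenalty: 1.1f,
            alphaFrequency: 0.0f,
            alphaPresence: 0.0f,
            penalizeNL: true);
    }
}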

LLama/Native/LLamaTokenDataArray.cs

Lines changed: 18 additions & 2 deletions

@@ -2,6 +2,8 @@
 using System.Buffers;
 using System.Runtime.InteropServices;

+using llama_token = System.Int32;
+
 namespace LLama.Native
 {
     /// <summary>
@@ -15,9 +17,9 @@ public struct LLamaTokenDataArray
     public readonly Memory<LLamaTokenData> data;

     /// <summary>
-    /// Indicates if `data` is sorted
+    /// Indicates if `data` is sorted by logits in descending order. If this is false the token data is in _no particular order_.
     /// </summary>
-    public readonly bool sorted;
+    public bool sorted;

     /// <summary>
     /// Create a new LLamaTokenDataArray
@@ -29,6 +31,20 @@ public LLamaTokenDataArray(Memory<LLamaTokenData> tokens, bool isSorted = false)
         data = tokens;
         sorted = isSorted;
     }
+
+    /// <summary>
+    /// Create a new LLamaTokenDataArray, copying the data from the given logits
+    /// </summary>
+    /// <param name="logits"></param>
+    /// <returns></returns>
+    public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits)
+    {
+        var candidates = new LLamaTokenData[logits.Length];
+        for (var token_id = 0; token_id < logits.Length; token_id++)
+            candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f);
+
+        return new LLamaTokenDataArray(candidates);
+    }
 }

 /// <summary>
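Note: Create copies each logit into an LLamaTokenData entry whose token id is simply its index, and the result starts out unsorted; making `sorted` mutable lets callers such as ApplyPenalty reset the flag after editing logits in place. A small sketch with illustrative values (the helper class is not part of the commit):

using LLama.Native;

internal static class CandidateArrayExample
{
    public static LLamaTokenDataArray FromLogits()
    {
        // Three-token toy vocabulary; token ids are the indices 0, 1, 2.
        var logits = new float[] { 0.1f, 2.5f, -0.3f };
        var candidates = LLamaTokenDataArray.Create(logits);

        // A freshly created array is in no particular order, so candidates.sorted is false here.
        return candidates;
    }
}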

LLama/Native/SamplingApi.cs

Lines changed: 29 additions & 2 deletions

@@ -25,12 +25,25 @@ public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDa
     /// <param name="last_tokens"></param>
     /// <param name="last_tokens_size"></param>
     /// <param name="penalty"></param>
+    [Obsolete("last_tokens_size parameter is no longer needed")]
     public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float penalty)
+    {
+        llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty);
+    }
+
+    /// <summary>
+    /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
+    /// </summary>
+    /// <param name="ctx"></param>
+    /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
+    /// <param name="last_tokens"></param>
+    /// <param name="penalty"></param>
+    public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float penalty)
     {
         using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
         using var last_tokens_handle = last_tokens.Pin();

-        NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, last_tokens_size, penalty);
+        NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty);
     }

     /// <summary>
@@ -42,12 +55,26 @@ public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, L
     /// <param name="last_tokens_size"></param>
     /// <param name="alpha_frequency"></param>
     /// <param name="alpha_presence"></param>
+    [Obsolete("last_tokens_size parameter is no longer needed")]
     public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
+    {
+        llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence);
+    }
+
+    /// <summary>
+    /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
+    /// </summary>
+    /// <param name="ctx"></param>
+    /// <param name="candidates">Pointer to LLamaTokenDataArray</param>
+    /// <param name="last_tokens"></param>
+    /// <param name="alpha_frequency"></param>
+    /// <param name="alpha_presence"></param>
+    public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float alpha_frequency, float alpha_presence)
     {
         using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
         using var last_tokens_handle = last_tokens.Pin();

-        NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, last_tokens_size, alpha_frequency, alpha_presence);
+        NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence);
     }

     /// <summary>
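Note: the old overloads that took last_tokens_size are kept but marked [Obsolete] and simply forward to the new ones, which read the element count from last_tokens.Length. A hedged call sketch follows; the wrapper method, penalty values, and the pre-existing ctx/candidates arguments are illustrative assumptions, not part of the commit.

using System;
using LLama.Native;

internal static class SamplerUsageExample
{
    // `ctx` must be a valid context handle and `candidates` a populated token data array.
    public static void Penalize(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<int> lastTokens)
    {
        // No separate size argument any more; the count is taken from lastTokens itself
        // (llama_token is an alias for System.Int32, so Memory<int> matches Memory<llama_token>).
        SamplingApi.llama_sample_repetition_penalty(ctx, candidates, lastTokens, 1.1f);
        SamplingApi.llama_sample_frequency_and_presence_penalties(ctx, candidates, lastTokens, 0.0f, 0.0f);
    }
}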
