@@ -1,26 +1,42 @@
 package de.kherud.llama;
 
-import java.lang.annotation.Native;
+
 import java.util.Iterator;
 import java.util.NoSuchElementException;
 
 /**
- * This iterator is used by {@link LlamaModel#generate(InferenceParameters)}. In addition to implementing {@link Iterator},
- * it allows to cancel ongoing inference (see {@link #cancel()}).
+ * Iterates over a stream of outputs from the model
  */
-public final class LlamaIterator implements Iterator<LlamaOutput> {
+public class LlamaIterator implements Iterator<LlamaOutput> {
 
     private final LlamaModel model;
+    private final boolean isChat;
     private final int taskId;
 
-    @Native
-    @SuppressWarnings("FieldMayBeFinal")
-    private boolean hasNext = true;
+    /**
+     * Whether there is a next token to receive
+     */
+    public boolean hasNext = true;
 
-    LlamaIterator(LlamaModel model, InferenceParameters parameters) {
+    /**
+     * Creates a new iterator
+     *
+     * @param model      the llama model to use for generating
+     * @param parameters parameters for the inference
+     * @param isChat     whether this is a chat completion (true) or regular
+     *                   completion (false)
+     */
+    LlamaIterator(LlamaModel model, InferenceParameters parameters, boolean isChat) {
         this.model = model;
-        parameters.setStream(true);
-        taskId = model.requestCompletion(parameters.toString());
+        this.isChat = isChat;
+
+        if (isChat) {
+            String prompt = model.applyTemplate(parameters);
+            parameters.setPrompt(prompt);
+            this.taskId = model.requestChat(parameters.toString());
+        } else {
+            this.taskId = model.requestCompletion(parameters.toString());
+        }
     }
 
     @Override
@@ -33,19 +49,38 @@ public LlamaOutput next() {
         if (!hasNext) {
             throw new NoSuchElementException();
         }
-        LlamaOutput output = model.receiveCompletion(taskId);
-        hasNext = !output.stop;
-        if (output.stop) {
-            model.releaseTask(taskId);
+
+        try {
+            if (isChat) {
+                String response = model.streamChatCompletion(taskId);
+                // Check for completion by examining the JSON response.
+                // This is a simplification - the actual implementation might need
+                // more sophisticated handling.
+                if (response != null && response.contains("\"finish_reason\":")) {
+                    hasNext = false;
+                }
+                return new LlamaOutput(response, !hasNext);
+            } else {
+                StreamingOutput output = model.streamCompletion(taskId);
+                hasNext = !output.isFinal;
+                return new LlamaOutput(output.text, output.isFinal);
+            }
+        } catch (Exception e) {
+            model.releaseTask(taskId);
+            hasNext = false;
+            throw new RuntimeException(e);
         }
-        return output;
     }
 
     /**
-     * Cancel the ongoing generation process.
+     * Cancel the ongoing generation process. This will stop the model from
+     * generating more tokens and release resources.
      */
     public void cancel() {
-        model.cancelCompletion(taskId);
-        hasNext = false;
+        if (hasNext) {
+            model.cancelCompletion(taskId);
+            model.releaseTask(taskId);
+            hasNext = false;
+        }
     }
-}
+}
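For context, a minimal sketch of how the reworked iterator might be consumed. Only the `LlamaIterator` calls reflect this diff; the helper class and its `streamToStdout` method are hypothetical, and a loaded `LlamaModel` is taken as given. Since the constructor is package-private, the sketch lives in `de.kherud.llama`.

```java
package de.kherud.llama;

// Hypothetical helper, not part of this patch. It assumes a LlamaModel has
// already been loaded elsewhere; only the LlamaIterator usage follows the diff.
final class StreamingExample {

    static void streamToStdout(LlamaModel model, InferenceParameters parameters) {
        // `false` selects the plain-completion path (requestCompletion);
        // `true` would route through applyTemplate/requestChat instead.
        LlamaIterator iterator = new LlamaIterator(model, parameters, false);
        try {
            while (iterator.hasNext()) {
                LlamaOutput output = iterator.next();
                System.out.print(output);
            }
        } finally {
            // After this change cancel() is guarded by the hasNext flag, so
            // calling it after normal completion is a harmless no-op.
            iterator.cancel();
        }
    }
}
```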