
Commit 64f80c1

Author: Vaijanath Rao
Commit message: updating code to support tool calls
Parent: 2a5a1b1

File tree: 10 files changed, +2,326 −817 lines

src/main/cpp/jllama.cpp (+1,682 −484)

Large diffs are not rendered by default.

src/main/cpp/jllama.h (+131 −57)

Some generated files are not rendered by default.

src/main/java/de/kherud/llama/InferenceParameters.java (+1 −1)

@@ -564,7 +564,7 @@ public InferenceParameters setTools(String... tools) {
 
         parameters.put(PARAM_TOOLS, "[" + toolBuilder.toString() + "]");
         parameters.put(PARAM_TOOL_CHOICE, toJsonString("required"));
-        // parameters.put(PARAM_PARALLEL_TOOL_CALLS, String.valueOf(false));
+        parameters.put(PARAM_PARALLEL_TOOL_CALLS, String.valueOf(true));
         return this;
     }
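
With this change, setTools(...) now always requests tool_choice = "required" and enables parallel tool calls rather than leaving the flag commented out. A minimal usage sketch, assuming the binding's usual fluent InferenceParameters API; the tool definition JSON and prompt below are illustrative, not part of this commit:

    // Hypothetical OpenAI-style function definition, serialized as JSON.
    String weatherTool =
            "{\"type\": \"function\", \"function\": {"
          + "\"name\": \"get_weather\","
          + "\"parameters\": {\"type\": \"object\", \"properties\": {"
          + "\"city\": {\"type\": \"string\"}}, \"required\": [\"city\"]}}}";

    InferenceParameters params = new InferenceParameters("What is the weather in Berlin?")
            .setTools(weatherTool);
    // After this commit, the request JSON will contain
    // "tool_choice": "required" and "parallel_tool_calls": true.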

src/main/java/de/kherud/llama/LlamaIterable.java (+7 −1)

@@ -2,14 +2,20 @@
 
 import org.jetbrains.annotations.NotNull;
 
+
 /**
  * An iterable used by {@link LlamaModel#generate(InferenceParameters)} that specifically returns a {@link LlamaIterator}.
  */
 @FunctionalInterface
 public interface LlamaIterable extends Iterable<LlamaOutput> {
 
+    /**
+     * Returns a LlamaIterator over elements of type LlamaOutput.
+     * This overrides the standard iterator() method to specifically return a LlamaIterator.
+     *
+     * @return a LlamaIterator instance
+     */
     @NotNull
     @Override
     LlamaIterator iterator();
-
 }
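
Because iterator() is covariantly declared to return LlamaIterator, callers can both stream with a plain for-each loop and reach cancel() without a cast. A short sketch, assuming the binding's usual LlamaModel setup; the model path is a placeholder:

    try (LlamaModel model = new LlamaModel(new ModelParameters().setModelFilePath("model.gguf"))) {
        LlamaIterable outputs = model.generate(new InferenceParameters("Tell me a joke."));
        LlamaIterator it = outputs.iterator(); // covariant return, no cast needed
        while (it.hasNext()) {
            System.out.print(it.next());
            // it.cancel(); // could stop generation early here
        }
    }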
src/main/java/de/kherud/llama/LlamaIterator.java

@@ -1,26 +1,42 @@
 package de.kherud.llama;
 
-import java.lang.annotation.Native;
+
 import java.util.Iterator;
 import java.util.NoSuchElementException;
 
 /**
- * This iterator is used by {@link LlamaModel#generate(InferenceParameters)}. In addition to implementing {@link Iterator},
- * it allows to cancel ongoing inference (see {@link #cancel()}).
+ * Iterates over a stream of outputs from the model
  */
-public final class LlamaIterator implements Iterator<LlamaOutput> {
+public class LlamaIterator implements Iterator<LlamaOutput> {
 
     private final LlamaModel model;
+    private final boolean isChat;
     private final int taskId;
 
-    @Native
-    @SuppressWarnings("FieldMayBeFinal")
-    private boolean hasNext = true;
+    /**
+     * Whether there is a next token to receive
+     */
+    public boolean hasNext = true;
 
-    LlamaIterator(LlamaModel model, InferenceParameters parameters) {
+    /**
+     * Creates a new iterator
+     *
+     * @param model      the llama model to use for generating
+     * @param parameters parameters for the inference
+     * @param isChat     whether this is a chat completion (true) or regular
+     *                   completion (false)
+     */
+    LlamaIterator(LlamaModel model, InferenceParameters parameters, boolean isChat) {
         this.model = model;
-        parameters.setStream(true);
-        taskId = model.requestCompletion(parameters.toString());
+        this.isChat = isChat;
+
+        if (isChat) {
+            String prompt = model.applyTemplate(parameters);
+            parameters.setPrompt(prompt);
+            this.taskId = model.requestChat(parameters.toString());
+        } else {
+            this.taskId = model.requestCompletion(parameters.toString());
+        }
     }
 
     @Override
@@ -33,19 +49,38 @@ public LlamaOutput next() {
         if (!hasNext) {
             throw new NoSuchElementException();
         }
-        LlamaOutput output = model.receiveCompletion(taskId);
-        hasNext = !output.stop;
-        if (output.stop) {
-            model.releaseTask(taskId);
+
+        try {
+            if (isChat) {
+                String response = model.streamChatCompletion(taskId);
+                // Check for completion by examining the JSON response
+                // This is a simplification - the actual implementation might need more
+                // sophisticated handling
+                if (response != null && response.contains("\"finish_reason\":")) {
+                    hasNext = false;
+                }
+                return new LlamaOutput(response, !hasNext);
+            } else {
+                StreamingOutput output = model.streamCompletion(taskId);
+                hasNext = !output.isFinal;
+                return new LlamaOutput(output.text, output.isFinal);
+            }
+        } catch (Exception e) {
+            model.releaseTask(taskId);
+            hasNext = false;
+            throw new RuntimeException(e);
         }
-        return output;
     }
 
     /**
-     * Cancel the ongoing generation process.
+     * Cancel the ongoing generation process. This will stop the model from
+     * generating more tokens and release resources.
     */
    public void cancel() {
-        model.cancelCompletion(taskId);
-        hasNext = false;
+        if (hasNext) {
+            model.cancelCompletion(taskId);
+            model.releaseTask(taskId);
+            hasNext = false;
+        }
    }
-}
+}
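
Taken together, the isChat flag chooses between two native paths: chat requests go through applyTemplate(...) and requestChat(...) and are streamed with streamChatCompletion(...), while plain completions keep using requestCompletion(...) and streamCompletion(...). A sketch of the chat flow as seen from inside the de.kherud.llama package (the constructor is package-private, and how chat messages are attached to InferenceParameters is assumed rather than shown in this diff):

    // model is an already-loaded LlamaModel; params carries the chat request.
    LlamaIterator it = new LlamaIterator(model, params, /* isChat = */ true);
    // The constructor has already applied the chat template and queued the
    // request, so iteration only pulls streamed JSON chunks.
    while (it.hasNext()) {
        LlamaOutput chunk = it.next(); // wraps raw JSON from streamChatCompletion()
        System.out.print(chunk);
        // hasNext flips to false once a chunk contains "finish_reason".
    }
    it.cancel(); // no-op here: the hasNext guard prevents double release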
