|
17 | 17 | package java.com.google.firebase.ai;
|
18 | 18 |
|
19 | 19 | import android.graphics.Bitmap;
|
| 20 | +import androidx.annotation.Nullable; |
20 | 21 | import com.google.common.util.concurrent.ListenableFuture;
|
21 | 22 | import com.google.firebase.ai.FirebaseAI;
|
22 | 23 | import com.google.firebase.ai.GenerativeModel;
|
| 24 | +import com.google.firebase.ai.LiveGenerativeModel; |
23 | 25 | import com.google.firebase.ai.java.ChatFutures;
|
24 | 26 | import com.google.firebase.ai.java.GenerativeModelFutures;
|
| 27 | +import com.google.firebase.ai.java.LiveModelFutures; |
| 28 | +import com.google.firebase.ai.java.LiveSessionFutures; |
25 | 29 | import com.google.firebase.ai.type.BlockReason;
|
26 | 30 | import com.google.firebase.ai.type.Candidate;
|
27 | 31 | import com.google.firebase.ai.type.Citation;
|
|
32 | 36 | import com.google.firebase.ai.type.FileDataPart;
|
33 | 37 | import com.google.firebase.ai.type.FinishReason;
|
34 | 38 | import com.google.firebase.ai.type.FunctionCallPart;
|
| 39 | +import com.google.firebase.ai.type.FunctionResponsePart; |
35 | 40 | import com.google.firebase.ai.type.GenerateContentResponse;
|
| 41 | +import com.google.firebase.ai.type.GenerationConfig; |
36 | 42 | import com.google.firebase.ai.type.HarmCategory;
|
37 | 43 | import com.google.firebase.ai.type.HarmProbability;
|
38 | 44 | import com.google.firebase.ai.type.HarmSeverity;
|
39 | 45 | import com.google.firebase.ai.type.ImagePart;
|
40 | 46 | import com.google.firebase.ai.type.InlineDataPart;
|
| 47 | +import com.google.firebase.ai.type.LiveGenerationConfig; |
| 48 | +import com.google.firebase.ai.type.LiveServerContent; |
| 49 | +import com.google.firebase.ai.type.LiveServerMessage; |
| 50 | +import com.google.firebase.ai.type.LiveServerSetupComplete; |
| 51 | +import com.google.firebase.ai.type.LiveServerToolCall; |
| 52 | +import com.google.firebase.ai.type.LiveServerToolCallCancellation; |
| 53 | +import com.google.firebase.ai.type.MediaData; |
41 | 54 | import com.google.firebase.ai.type.ModalityTokenCount;
|
42 | 55 | import com.google.firebase.ai.type.Part;
|
43 | 56 | import com.google.firebase.ai.type.PromptFeedback;
|
| 57 | +import com.google.firebase.ai.type.PublicPreviewAPI; |
| 58 | +import com.google.firebase.ai.type.ResponseModality; |
44 | 59 | import com.google.firebase.ai.type.SafetyRating;
|
| 60 | +import com.google.firebase.ai.type.SpeechConfig; |
45 | 61 | import com.google.firebase.ai.type.TextPart;
|
46 | 62 | import com.google.firebase.ai.type.UsageMetadata;
|
| 63 | +import com.google.firebase.ai.type.Voices; |
47 | 64 | import com.google.firebase.concurrent.FirebaseExecutors;
|
48 | 65 | import java.util.Calendar;
|
49 | 66 | import java.util.List;
|
50 | 67 | import java.util.Map;
|
51 | 68 | import java.util.concurrent.Executor;
|
| 69 | +import kotlin.OptIn; |
52 | 70 | import kotlinx.serialization.json.JsonElement;
|
53 | 71 | import kotlinx.serialization.json.JsonNull;
|
| 72 | +import kotlinx.serialization.json.JsonObject; |
54 | 73 | import org.junit.Assert;
|
55 | 74 | import org.reactivestreams.Publisher;
|
56 | 75 | import org.reactivestreams.Subscriber;
|
|
59 | 78 | /**
|
60 | 79 | * Tests in this file exist to be compiled, not invoked
|
61 | 80 | */
|
| 81 | +@OptIn(markerClass = PublicPreviewAPI.class) |
62 | 82 | public class JavaCompileTests {
|
63 | 83 |
|
64 | 84 | public void initializeJava() throws Exception {
|
65 | 85 | FirebaseAI vertex = FirebaseAI.getInstance();
|
66 |
| - GenerativeModel model = vertex.generativeModel("fake-model-name"); |
| 86 | + GenerativeModel model = vertex.generativeModel("fake-model-name", getConfig()); |
| 87 | + LiveGenerativeModel live = vertex.liveModel("fake-model-name", getLiveConfig()); |
67 | 88 | GenerativeModelFutures futures = GenerativeModelFutures.from(model);
|
| 89 | + LiveModelFutures liveFutures = LiveModelFutures.from(live); |
68 | 90 | testFutures(futures);
|
| 91 | + testLiveFutures(liveFutures); |
| 92 | + } |
| 93 | + |
| 94 | + private GenerationConfig getConfig() { |
| 95 | + return new GenerationConfig.Builder().build(); |
| 96 | + // TODO b/406558430 GenerationConfig.Builder.setParts returns void |
| 97 | + } |
| 98 | + |
| 99 | + private LiveGenerationConfig getLiveConfig() { |
| 100 | + return new LiveGenerationConfig.Builder() |
| 101 | + .setTopK(10) |
| 102 | + .setTopP(11.0F) |
| 103 | + .setTemperature(32.0F) |
| 104 | + .setCandidateCount(1) |
| 105 | + .setMaxOutputTokens(0xCAFEBABE) |
| 106 | + .setFrequencyPenalty(1.0F) |
| 107 | + .setPresencePenalty(2.0F) |
| 108 | + .setResponseModality(ResponseModality.AUDIO) |
| 109 | + .setSpeechConfig(new SpeechConfig(Voices.AOEDE)) |
| 110 | + .build(); |
69 | 111 | }
|
70 | 112 |
|
71 | 113 | private void testFutures(GenerativeModelFutures futures) throws Exception {
|
@@ -159,7 +201,10 @@ public void validateCandidates(List<Candidate> candidates) {
|
159 | 201 | }
|
160 | 202 | }
|
161 | 203 |
|
162 |
| - public void validateContent(Content content) { |
| 204 | + public void validateContent(@Nullable Content content) { |
| 205 | + if (content == null) { |
| 206 | + return; |
| 207 | + } |
163 | 208 | String role = content.getRole();
|
164 | 209 | for (Part part : content.getParts()) {
|
165 | 210 | if (part instanceof TextPart) {
|
@@ -236,4 +281,67 @@ public void validateUsageMetadata(UsageMetadata metadata) {
|
236 | 281 | }
|
237 | 282 | }
|
238 | 283 | }
|
| 284 | + |
| 285 | + private void testLiveFutures(LiveModelFutures futures) throws Exception { |
| 286 | + LiveSessionFutures session = futures.connect().get(); |
| 287 | + session |
| 288 | + .receive() |
| 289 | + .subscribe( |
| 290 | + new Subscriber<LiveServerMessage>() { |
| 291 | + @Override |
| 292 | + public void onSubscribe(Subscription s) { |
| 293 | + s.request(Long.MAX_VALUE); |
| 294 | + } |
| 295 | + |
| 296 | + @Override |
| 297 | + public void onNext(LiveServerMessage message) { |
| 298 | + validateLiveContentResponse(message); |
| 299 | + } |
| 300 | + |
| 301 | + @Override |
| 302 | + public void onError(Throwable t) { |
| 303 | + // Ignore |
| 304 | + } |
| 305 | + |
| 306 | + @Override |
| 307 | + public void onComplete() { |
| 308 | + // Also ignore |
| 309 | + } |
| 310 | + }); |
| 311 | + |
| 312 | + session.send("Fake message"); |
| 313 | + session.send(new Content.Builder().addText("Fake message").build()); |
| 314 | + |
| 315 | + byte[] bytes = new byte[] {(byte) 0xCA, (byte) 0xFE, (byte) 0xBA, (byte) 0xBE}; |
| 316 | + session.sendMediaStream(List.of(new MediaData(bytes, "image/jxl"))); |
| 317 | + |
| 318 | + FunctionResponsePart functionResponse = |
| 319 | + new FunctionResponsePart("myFunction", new JsonObject(Map.of())); |
| 320 | + session.sendFunctionResponse(List.of(functionResponse, functionResponse)); |
| 321 | + |
| 322 | + session.startAudioConversation(part -> functionResponse); |
| 323 | + session.startAudioConversation(); |
| 324 | + session.stopAudioConversation(); |
| 325 | + session.stopReceiving(); |
| 326 | + session.close(); |
| 327 | + } |
| 328 | + |
| 329 | + private void validateLiveContentResponse(LiveServerMessage message) { |
| 330 | + if (message instanceof LiveServerContent) { |
| 331 | + LiveServerContent content = (LiveServerContent) message; |
| 332 | + validateContent(content.getContent()); |
| 333 | + boolean complete = content.getGenerationComplete(); |
| 334 | + boolean interrupted = content.getInterrupted(); |
| 335 | + boolean turnComplete = content.getTurnComplete(); |
| 336 | + } else if (message instanceof LiveServerSetupComplete) { |
| 337 | + LiveServerSetupComplete setup = (LiveServerSetupComplete) message; |
| 338 | + // No methods |
| 339 | + } else if (message instanceof LiveServerToolCall) { |
| 340 | + LiveServerToolCall call = (LiveServerToolCall) message; |
| 341 | + validateFunctionCalls(call.getFunctionCalls()); |
| 342 | + } else if (message instanceof LiveServerToolCallCancellation) { |
| 343 | + LiveServerToolCallCancellation cancel = (LiveServerToolCallCancellation) message; |
| 344 | + List<String> functions = cancel.getFunctionIds(); |
| 345 | + } |
| 346 | + } |
239 | 347 | }
|
0 commit comments