Skip to content

Commit 819f05f

Browse files
authored
Merge 6b9352e into 2b23887
2 parents 2b23887 + 6b9352e commit 819f05f

File tree

2 files changed

+203
-3
lines changed

2 files changed

+203
-3
lines changed

firebase-ai/src/testUtil/java/com/google/firebase/ai/JavaCompileTests.java

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@
1717
package java.com.google.firebase.ai;
1818

1919
import android.graphics.Bitmap;
20+
import androidx.annotation.Nullable;
2021
import com.google.common.util.concurrent.ListenableFuture;
2122
import com.google.firebase.ai.FirebaseAI;
2223
import com.google.firebase.ai.GenerativeModel;
24+
import com.google.firebase.ai.LiveGenerativeModel;
2325
import com.google.firebase.ai.java.ChatFutures;
2426
import com.google.firebase.ai.java.GenerativeModelFutures;
27+
import com.google.firebase.ai.java.LiveModelFutures;
28+
import com.google.firebase.ai.java.LiveSessionFutures;
2529
import com.google.firebase.ai.type.BlockReason;
2630
import com.google.firebase.ai.type.Candidate;
2731
import com.google.firebase.ai.type.Citation;
@@ -32,25 +36,40 @@
3236
import com.google.firebase.ai.type.FileDataPart;
3337
import com.google.firebase.ai.type.FinishReason;
3438
import com.google.firebase.ai.type.FunctionCallPart;
39+
import com.google.firebase.ai.type.FunctionResponsePart;
3540
import com.google.firebase.ai.type.GenerateContentResponse;
41+
import com.google.firebase.ai.type.GenerationConfig;
3642
import com.google.firebase.ai.type.HarmCategory;
3743
import com.google.firebase.ai.type.HarmProbability;
3844
import com.google.firebase.ai.type.HarmSeverity;
3945
import com.google.firebase.ai.type.ImagePart;
4046
import com.google.firebase.ai.type.InlineDataPart;
47+
import com.google.firebase.ai.type.LiveGenerationConfig;
48+
import com.google.firebase.ai.type.LiveServerContent;
49+
import com.google.firebase.ai.type.LiveServerMessage;
50+
import com.google.firebase.ai.type.LiveServerSetupComplete;
51+
import com.google.firebase.ai.type.LiveServerToolCall;
52+
import com.google.firebase.ai.type.LiveServerToolCallCancellation;
53+
import com.google.firebase.ai.type.MediaData;
4154
import com.google.firebase.ai.type.ModalityTokenCount;
4255
import com.google.firebase.ai.type.Part;
4356
import com.google.firebase.ai.type.PromptFeedback;
57+
import com.google.firebase.ai.type.PublicPreviewAPI;
58+
import com.google.firebase.ai.type.ResponseModality;
4459
import com.google.firebase.ai.type.SafetyRating;
60+
import com.google.firebase.ai.type.SpeechConfig;
4561
import com.google.firebase.ai.type.TextPart;
4662
import com.google.firebase.ai.type.UsageMetadata;
63+
import com.google.firebase.ai.type.Voices;
4764
import com.google.firebase.concurrent.FirebaseExecutors;
4865
import java.util.Calendar;
4966
import java.util.List;
5067
import java.util.Map;
5168
import java.util.concurrent.Executor;
69+
import kotlin.OptIn;
5270
import kotlinx.serialization.json.JsonElement;
5371
import kotlinx.serialization.json.JsonNull;
72+
import kotlinx.serialization.json.JsonObject;
5473
import org.junit.Assert;
5574
import org.reactivestreams.Publisher;
5675
import org.reactivestreams.Subscriber;
@@ -59,13 +78,36 @@
5978
/**
6079
* Tests in this file exist to be compiled, not invoked
6180
*/
81+
@OptIn(markerClass = PublicPreviewAPI.class)
6282
public class JavaCompileTests {
6383

6484
public void initializeJava() throws Exception {
6585
FirebaseAI vertex = FirebaseAI.getInstance();
66-
GenerativeModel model = vertex.generativeModel("fake-model-name");
86+
GenerativeModel model = vertex.generativeModel("fake-model-name", getConfig());
87+
LiveGenerativeModel live = vertex.liveModel("fake-model-name", getLiveConfig());
6788
GenerativeModelFutures futures = GenerativeModelFutures.from(model);
89+
LiveModelFutures liveFutures = LiveModelFutures.from(live);
6890
testFutures(futures);
91+
testLiveFutures(liveFutures);
92+
}
93+
94+
private GenerationConfig getConfig() {
95+
return new GenerationConfig.Builder().build();
96+
// TODO b/406558430 GenerationConfig.Builder.setParts returns void
97+
}
98+
99+
private LiveGenerationConfig getLiveConfig() {
100+
return new LiveGenerationConfig.Builder()
101+
.setTopK(10)
102+
.setTopP(11.0F)
103+
.setTemperature(32.0F)
104+
.setCandidateCount(1)
105+
.setMaxOutputTokens(0xCAFEBABE)
106+
.setFrequencyPenalty(1.0F)
107+
.setPresencePenalty(2.0F)
108+
.setResponseModality(ResponseModality.AUDIO)
109+
.setSpeechConfig(new SpeechConfig(Voices.AOEDE))
110+
.build();
69111
}
70112

71113
private void testFutures(GenerativeModelFutures futures) throws Exception {
@@ -159,7 +201,10 @@ public void validateCandidates(List<Candidate> candidates) {
159201
}
160202
}
161203

162-
public void validateContent(Content content) {
204+
public void validateContent(@Nullable Content content) {
205+
if (content == null) {
206+
return;
207+
}
163208
String role = content.getRole();
164209
for (Part part : content.getParts()) {
165210
if (part instanceof TextPart) {
@@ -236,4 +281,67 @@ public void validateUsageMetadata(UsageMetadata metadata) {
236281
}
237282
}
238283
}
284+
285+
private void testLiveFutures(LiveModelFutures futures) throws Exception {
286+
LiveSessionFutures session = futures.connect().get();
287+
session
288+
.receive()
289+
.subscribe(
290+
new Subscriber<LiveServerMessage>() {
291+
@Override
292+
public void onSubscribe(Subscription s) {
293+
s.request(Long.MAX_VALUE);
294+
}
295+
296+
@Override
297+
public void onNext(LiveServerMessage message) {
298+
validateLiveContentResponse(message);
299+
}
300+
301+
@Override
302+
public void onError(Throwable t) {
303+
// Ignore
304+
}
305+
306+
@Override
307+
public void onComplete() {
308+
// Also ignore
309+
}
310+
});
311+
312+
session.send("Fake message");
313+
session.send(new Content.Builder().addText("Fake message").build());
314+
315+
byte[] bytes = new byte[] {(byte) 0xCA, (byte) 0xFE, (byte) 0xBA, (byte) 0xBE};
316+
session.sendMediaStream(List.of(new MediaData(bytes, "image/jxl")));
317+
318+
FunctionResponsePart functionResponse =
319+
new FunctionResponsePart("myFunction", new JsonObject(Map.of()));
320+
session.sendFunctionResponse(List.of(functionResponse, functionResponse));
321+
322+
session.startAudioConversation(part -> functionResponse);
323+
session.startAudioConversation();
324+
session.stopAudioConversation();
325+
session.stopReceiving();
326+
session.close();
327+
}
328+
329+
private void validateLiveContentResponse(LiveServerMessage message) {
330+
if (message instanceof LiveServerContent) {
331+
LiveServerContent content = (LiveServerContent) message;
332+
validateContent(content.getContent());
333+
boolean complete = content.getGenerationComplete();
334+
boolean interrupted = content.getInterrupted();
335+
boolean turnComplete = content.getTurnComplete();
336+
} else if (message instanceof LiveServerSetupComplete) {
337+
LiveServerSetupComplete setup = (LiveServerSetupComplete) message;
338+
// No methods
339+
} else if (message instanceof LiveServerToolCall) {
340+
LiveServerToolCall call = (LiveServerToolCall) message;
341+
validateFunctionCalls(call.getFunctionCalls());
342+
} else if (message instanceof LiveServerToolCallCancellation) {
343+
LiveServerToolCallCancellation cancel = (LiveServerToolCallCancellation) message;
344+
List<String> functions = cancel.getFunctionIds();
345+
}
346+
}
239347
}

firebase-vertexai/src/testUtil/java/com/google/firebase/vertexai/JavaCompileTests.java

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@
2121
import com.google.firebase.concurrent.FirebaseExecutors;
2222
import com.google.firebase.vertexai.FirebaseVertexAI;
2323
import com.google.firebase.vertexai.GenerativeModel;
24+
import com.google.firebase.vertexai.LiveGenerativeModel;
2425
import com.google.firebase.vertexai.java.ChatFutures;
2526
import com.google.firebase.vertexai.java.GenerativeModelFutures;
27+
import com.google.firebase.vertexai.java.LiveModelFutures;
28+
import com.google.firebase.vertexai.java.LiveSessionFutures;
2629
import com.google.firebase.vertexai.type.BlockReason;
2730
import com.google.firebase.vertexai.type.Candidate;
2831
import com.google.firebase.vertexai.type.Citation;
@@ -33,24 +36,33 @@
3336
import com.google.firebase.vertexai.type.FileDataPart;
3437
import com.google.firebase.vertexai.type.FinishReason;
3538
import com.google.firebase.vertexai.type.FunctionCallPart;
39+
import com.google.firebase.vertexai.type.FunctionResponsePart;
3640
import com.google.firebase.vertexai.type.GenerateContentResponse;
41+
import com.google.firebase.vertexai.type.GenerationConfig;
3742
import com.google.firebase.vertexai.type.HarmCategory;
3843
import com.google.firebase.vertexai.type.HarmProbability;
3944
import com.google.firebase.vertexai.type.HarmSeverity;
4045
import com.google.firebase.vertexai.type.ImagePart;
4146
import com.google.firebase.vertexai.type.InlineDataPart;
47+
import com.google.firebase.vertexai.type.LiveContentResponse;
48+
import com.google.firebase.vertexai.type.LiveGenerationConfig;
49+
import com.google.firebase.vertexai.type.MediaData;
4250
import com.google.firebase.vertexai.type.ModalityTokenCount;
4351
import com.google.firebase.vertexai.type.Part;
4452
import com.google.firebase.vertexai.type.PromptFeedback;
53+
import com.google.firebase.vertexai.type.ResponseModality;
4554
import com.google.firebase.vertexai.type.SafetyRating;
55+
import com.google.firebase.vertexai.type.SpeechConfig;
4656
import com.google.firebase.vertexai.type.TextPart;
4757
import com.google.firebase.vertexai.type.UsageMetadata;
58+
import com.google.firebase.vertexai.type.Voices;
4859
import java.util.Calendar;
4960
import java.util.List;
5061
import java.util.Map;
5162
import java.util.concurrent.Executor;
5263
import kotlinx.serialization.json.JsonElement;
5364
import kotlinx.serialization.json.JsonNull;
65+
import kotlinx.serialization.json.JsonObject;
5466
import org.junit.Assert;
5567
import org.reactivestreams.Publisher;
5668
import org.reactivestreams.Subscriber;
@@ -63,9 +75,31 @@ public class JavaCompileTests {
6375

6476
public void initializeJava() throws Exception {
6577
FirebaseVertexAI vertex = FirebaseVertexAI.getInstance();
66-
GenerativeModel model = vertex.generativeModel("fake-model-name");
78+
GenerativeModel model = vertex.generativeModel("fake-model-name", getConfig());
79+
LiveGenerativeModel live = vertex.liveModel("fake-model-name", getLiveConfig());
6780
GenerativeModelFutures futures = GenerativeModelFutures.from(model);
81+
LiveModelFutures liveFutures = LiveModelFutures.from(live);
6882
testFutures(futures);
83+
testLiveFutures(liveFutures);
84+
}
85+
86+
private GenerationConfig getConfig() {
87+
return new GenerationConfig.Builder().build();
88+
// TODO b/406558430 GenerationConfig.Builder.setParts returns void
89+
}
90+
91+
private LiveGenerationConfig getLiveConfig() {
92+
return new LiveGenerationConfig.Builder()
93+
.setTopK(10)
94+
.setTopP(11.0F)
95+
.setTemperature(32.0F)
96+
.setCandidateCount(1)
97+
.setMaxOutputTokens(0xCAFEBABE)
98+
.setFrequencyPenalty(1.0F)
99+
.setPresencePenalty(2.0F)
100+
.setResponseModality(ResponseModality.AUDIO)
101+
.setSpeechConfig(new SpeechConfig(Voices.AOEDE))
102+
.build();
69103
}
70104

71105
private void testFutures(GenerativeModelFutures futures) throws Exception {
@@ -236,4 +270,62 @@ public void validateUsageMetadata(UsageMetadata metadata) {
236270
}
237271
}
238272
}
273+
274+
private void testLiveFutures(LiveModelFutures futures) throws Exception {
275+
LiveSessionFutures session = futures.connect().get();
276+
session
277+
.receive()
278+
.subscribe(
279+
new Subscriber<LiveContentResponse>() {
280+
@Override
281+
public void onSubscribe(Subscription s) {
282+
s.request(Long.MAX_VALUE);
283+
}
284+
285+
@Override
286+
public void onNext(LiveContentResponse response) {
287+
validateLiveContentResponse(response);
288+
}
289+
290+
@Override
291+
public void onError(Throwable t) {
292+
// Ignore
293+
}
294+
295+
@Override
296+
public void onComplete() {
297+
// Also ignore
298+
}
299+
});
300+
301+
session.send("Fake message");
302+
session.send(new Content.Builder().addText("Fake message").build());
303+
304+
byte[] bytes = new byte[] {(byte) 0xCA, (byte) 0xFE, (byte) 0xBA, (byte) 0xBE};
305+
session.sendMediaStream(List.of(new MediaData(bytes, "image/jxl")));
306+
307+
FunctionResponsePart functionResponse =
308+
new FunctionResponsePart("myFunction", new JsonObject(Map.of()));
309+
session.sendFunctionResponse(List.of(functionResponse, functionResponse));
310+
311+
session.startAudioConversation(part -> functionResponse);
312+
session.startAudioConversation();
313+
session.stopAudioConversation();
314+
session.stopReceiving();
315+
session.close();
316+
}
317+
318+
private void validateLiveContentResponse(LiveContentResponse response) {
319+
// int status = response.getStatus();
320+
// Assert.assertEquals(status, LiveContentResponse.Status.Companion.getNORMAL());
321+
// Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getINTERRUPTED());
322+
// Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getTURN_COMPLETE());
323+
// TODO b/412743328 LiveContentResponse.Status inaccessible for Java users
324+
Content data = response.getData();
325+
if (data != null) {
326+
validateContent(data);
327+
}
328+
String text = response.getText();
329+
validateFunctionCalls(response.getFunctionCalls());
330+
}
239331
}

0 commit comments

Comments
 (0)