Skip to content

Commit 6b9352e

Browse files
committed
Update tests
1 parent 4b4bf27 commit 6b9352e

File tree

1 file changed

+29
-16
lines changed

1 file changed

+29
-16
lines changed

firebase-ai/src/testUtil/java/com/google/firebase/ai/JavaCompileTests.java

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package java.com.google.firebase.ai;
1818

1919
import android.graphics.Bitmap;
20+
import androidx.annotation.Nullable;
2021
import com.google.common.util.concurrent.ListenableFuture;
2122
import com.google.firebase.ai.FirebaseAI;
2223
import com.google.firebase.ai.GenerativeModel;
@@ -43,8 +44,12 @@
4344
import com.google.firebase.ai.type.HarmSeverity;
4445
import com.google.firebase.ai.type.ImagePart;
4546
import com.google.firebase.ai.type.InlineDataPart;
46-
import com.google.firebase.ai.type.LiveContentResponse;
4747
import com.google.firebase.ai.type.LiveGenerationConfig;
48+
import com.google.firebase.ai.type.LiveServerContent;
49+
import com.google.firebase.ai.type.LiveServerMessage;
50+
import com.google.firebase.ai.type.LiveServerSetupComplete;
51+
import com.google.firebase.ai.type.LiveServerToolCall;
52+
import com.google.firebase.ai.type.LiveServerToolCallCancellation;
4853
import com.google.firebase.ai.type.MediaData;
4954
import com.google.firebase.ai.type.ModalityTokenCount;
5055
import com.google.firebase.ai.type.Part;
@@ -196,7 +201,10 @@ public void validateCandidates(List<Candidate> candidates) {
196201
}
197202
}
198203

199-
public void validateContent(Content content) {
204+
public void validateContent(@Nullable Content content) {
205+
if (content == null) {
206+
return;
207+
}
200208
String role = content.getRole();
201209
for (Part part : content.getParts()) {
202210
if (part instanceof TextPart) {
@@ -279,15 +287,15 @@ private void testLiveFutures(LiveModelFutures futures) throws Exception {
279287
session
280288
.receive()
281289
.subscribe(
282-
new Subscriber<LiveContentResponse>() {
290+
new Subscriber<LiveServerMessage>() {
283291
@Override
284292
public void onSubscribe(Subscription s) {
285293
s.request(Long.MAX_VALUE);
286294
}
287295

288296
@Override
289-
public void onNext(LiveContentResponse response) {
290-
validateLiveContentResponse(response);
297+
public void onNext(LiveServerMessage message) {
298+
validateLiveContentResponse(message);
291299
}
292300

293301
@Override
@@ -318,17 +326,22 @@ public void onComplete() {
318326
session.close();
319327
}
320328

321-
private void validateLiveContentResponse(LiveContentResponse response) {
322-
// int status = response.getStatus();
323-
// Assert.assertEquals(status, LiveContentResponse.Status.Companion.getNORMAL());
324-
// Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getINTERRUPTED());
325-
// Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getTURN_COMPLETE());
326-
// TODO b/412743328 LiveContentResponse.Status inaccessible for Java users
327-
Content data = response.getData();
328-
if (data != null) {
329-
validateContent(data);
329+
private void validateLiveContentResponse(LiveServerMessage message) {
330+
if (message instanceof LiveServerContent) {
331+
LiveServerContent content = (LiveServerContent) message;
332+
validateContent(content.getContent());
333+
boolean complete = content.getGenerationComplete();
334+
boolean interrupted = content.getInterrupted();
335+
boolean turnComplete = content.getTurnComplete();
336+
} else if (message instanceof LiveServerSetupComplete) {
337+
LiveServerSetupComplete setup = (LiveServerSetupComplete) message;
338+
// No methods
339+
} else if (message instanceof LiveServerToolCall) {
340+
LiveServerToolCall call = (LiveServerToolCall) message;
341+
validateFunctionCalls(call.getFunctionCalls());
342+
} else if (message instanceof LiveServerToolCallCancellation) {
343+
LiveServerToolCallCancellation cancel = (LiveServerToolCallCancellation) message;
344+
List<String> functions = cancel.getFunctionIds();
330345
}
331-
String text = response.getText();
332-
validateFunctionCalls(response.getFunctionCalls());
333346
}
334347
}

0 commit comments

Comments
 (0)