|
17 | 17 | package java.com.google.firebase.ai;
|
18 | 18 |
|
19 | 19 | import android.graphics.Bitmap;
|
| 20 | +import androidx.annotation.Nullable; |
20 | 21 | import com.google.common.util.concurrent.ListenableFuture;
|
21 | 22 | import com.google.firebase.ai.FirebaseAI;
|
22 | 23 | import com.google.firebase.ai.GenerativeModel;
|
|
43 | 44 | import com.google.firebase.ai.type.HarmSeverity;
|
44 | 45 | import com.google.firebase.ai.type.ImagePart;
|
45 | 46 | import com.google.firebase.ai.type.InlineDataPart;
|
46 |
| -import com.google.firebase.ai.type.LiveContentResponse; |
47 | 47 | import com.google.firebase.ai.type.LiveGenerationConfig;
|
| 48 | +import com.google.firebase.ai.type.LiveServerContent; |
| 49 | +import com.google.firebase.ai.type.LiveServerMessage; |
| 50 | +import com.google.firebase.ai.type.LiveServerSetupComplete; |
| 51 | +import com.google.firebase.ai.type.LiveServerToolCall; |
| 52 | +import com.google.firebase.ai.type.LiveServerToolCallCancellation; |
48 | 53 | import com.google.firebase.ai.type.MediaData;
|
49 | 54 | import com.google.firebase.ai.type.ModalityTokenCount;
|
50 | 55 | import com.google.firebase.ai.type.Part;
|
@@ -196,7 +201,10 @@ public void validateCandidates(List<Candidate> candidates) {
|
196 | 201 | }
|
197 | 202 | }
|
198 | 203 |
|
199 |
| - public void validateContent(Content content) { |
| 204 | + public void validateContent(@Nullable Content content) { |
| 205 | + if (content == null) { |
| 206 | + return; |
| 207 | + } |
200 | 208 | String role = content.getRole();
|
201 | 209 | for (Part part : content.getParts()) {
|
202 | 210 | if (part instanceof TextPart) {
|
@@ -279,15 +287,15 @@ private void testLiveFutures(LiveModelFutures futures) throws Exception {
|
279 | 287 | session
|
280 | 288 | .receive()
|
281 | 289 | .subscribe(
|
282 |
| - new Subscriber<LiveContentResponse>() { |
| 290 | + new Subscriber<LiveServerMessage>() { |
283 | 291 | @Override
|
284 | 292 | public void onSubscribe(Subscription s) {
|
285 | 293 | s.request(Long.MAX_VALUE);
|
286 | 294 | }
|
287 | 295 |
|
288 | 296 | @Override
|
289 |
| - public void onNext(LiveContentResponse response) { |
290 |
| - validateLiveContentResponse(response); |
| 297 | + public void onNext(LiveServerMessage message) { |
| 298 | + validateLiveContentResponse(message); |
291 | 299 | }
|
292 | 300 |
|
293 | 301 | @Override
|
@@ -318,17 +326,22 @@ public void onComplete() {
|
318 | 326 | session.close();
|
319 | 327 | }
|
320 | 328 |
|
321 |
| - private void validateLiveContentResponse(LiveContentResponse response) { |
322 |
| - // int status = response.getStatus(); |
323 |
| - // Assert.assertEquals(status, LiveContentResponse.Status.Companion.getNORMAL()); |
324 |
| - // Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getINTERRUPTED()); |
325 |
| - // Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getTURN_COMPLETE()); |
326 |
| - // TODO b/412743328 LiveContentResponse.Status inaccessible for Java users |
327 |
| - Content data = response.getData(); |
328 |
| - if (data != null) { |
329 |
| - validateContent(data); |
| 329 | + private void validateLiveContentResponse(LiveServerMessage message) { |
| 330 | + if (message instanceof LiveServerContent) { |
| 331 | + LiveServerContent content = (LiveServerContent) message; |
| 332 | + validateContent(content.getContent()); |
| 333 | + boolean complete = content.getGenerationComplete(); |
| 334 | + boolean interrupted = content.getInterrupted(); |
| 335 | + boolean turnComplete = content.getTurnComplete(); |
| 336 | + } else if (message instanceof LiveServerSetupComplete) { |
| 337 | + LiveServerSetupComplete setup = (LiveServerSetupComplete) message; |
| 338 | + // No methods |
| 339 | + } else if (message instanceof LiveServerToolCall) { |
| 340 | + LiveServerToolCall call = (LiveServerToolCall) message; |
| 341 | + validateFunctionCalls(call.getFunctionCalls()); |
| 342 | + } else if (message instanceof LiveServerToolCallCancellation) { |
| 343 | + LiveServerToolCallCancellation cancel = (LiveServerToolCallCancellation) message; |
| 344 | + List<String> functions = cancel.getFunctionIds(); |
330 | 345 | }
|
331 |
| - String text = response.getText(); |
332 |
| - validateFunctionCalls(response.getFunctionCalls()); |
333 | 346 | }
|
334 | 347 | }
|
0 commit comments