|
44 | 44 | import com.google.firebase.vertexai.type.HarmSeverity;
|
45 | 45 | import com.google.firebase.vertexai.type.ImagePart;
|
46 | 46 | import com.google.firebase.vertexai.type.InlineDataPart;
|
47 |
| -import com.google.firebase.vertexai.type.LiveContentResponse; |
48 | 47 | import com.google.firebase.vertexai.type.LiveGenerationConfig;
|
| 48 | +import com.google.firebase.vertexai.type.LiveServerContent; |
| 49 | +import com.google.firebase.vertexai.type.LiveServerMessage; |
| 50 | +import com.google.firebase.vertexai.type.LiveServerSetupComplete; |
| 51 | +import com.google.firebase.vertexai.type.LiveServerToolCall; |
| 52 | +import com.google.firebase.vertexai.type.LiveServerToolCallCancellation; |
49 | 53 | import com.google.firebase.vertexai.type.MediaData;
|
50 | 54 | import com.google.firebase.vertexai.type.ModalityTokenCount;
|
51 | 55 | import com.google.firebase.vertexai.type.Part;
|
@@ -276,14 +280,14 @@ private void testLiveFutures(LiveModelFutures futures) throws Exception {
|
276 | 280 | session
|
277 | 281 | .receive()
|
278 | 282 | .subscribe(
|
279 |
| - new Subscriber<LiveContentResponse>() { |
| 283 | + new Subscriber<LiveServerMessage>() { |
280 | 284 | @Override
|
281 | 285 | public void onSubscribe(Subscription s) {
|
282 | 286 | s.request(Long.MAX_VALUE);
|
283 | 287 | }
|
284 | 288 |
|
285 | 289 | @Override
|
286 |
| - public void onNext(LiveContentResponse response) { |
| 290 | + public void onNext(LiveServerMessage response) { |
287 | 291 | validateLiveContentResponse(response);
|
288 | 292 | }
|
289 | 293 |
|
@@ -315,17 +319,22 @@ public void onComplete() {
|
315 | 319 | session.close();
|
316 | 320 | }
|
317 | 321 |
|
318 |
| - private void validateLiveContentResponse(LiveContentResponse response) { |
319 |
| - // int status = response.getStatus(); |
320 |
| - // Assert.assertEquals(status, LiveContentResponse.Status.Companion.getNORMAL()); |
321 |
| - // Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getINTERRUPTED()); |
322 |
| - // Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getTURN_COMPLETE()); |
323 |
| - // TODO b/412743328 LiveContentResponse.Status inaccessible for Java users |
324 |
| - Content data = response.getData(); |
325 |
| - if (data != null) { |
326 |
| - validateContent(data); |
| 322 | + private void validateLiveContentResponse(LiveServerMessage message) { |
| 323 | + if (message instanceof LiveServerContent) { |
| 324 | + LiveServerContent content = (LiveServerContent) message; |
| 325 | + validateContent(content.getContent()); |
| 326 | + boolean complete = content.getGenerationComplete(); |
| 327 | + boolean interrupted = content.getInterrupted(); |
| 328 | + boolean turnComplete = content.getTurnComplete(); |
| 329 | + } else if (message instanceof LiveServerSetupComplete) { |
| 330 | + LiveServerSetupComplete setup = (LiveServerSetupComplete) message; |
| 331 | + // No methods |
| 332 | + } else if (message instanceof LiveServerToolCall) { |
| 333 | + LiveServerToolCall call = (LiveServerToolCall) message; |
| 334 | + validateFunctionCalls(call.getFunctionCalls()); |
| 335 | + } else if (message instanceof LiveServerToolCallCancellation) { |
| 336 | + LiveServerToolCallCancellation cancel = (LiveServerToolCallCancellation) message; |
| 337 | + List<String> functions = cancel.getFunctionIds(); |
327 | 338 | }
|
328 |
| - String text = response.getText(); |
329 |
| - validateFunctionCalls(response.getFunctionCalls()); |
330 | 339 | }
|
331 | 340 | }
|
0 commit comments