@@ -946,6 +946,56 @@ def decode_batch(seq_sizes: List[int]):
         else:
             return output
 
+    def _create_chunk(
+        self,
+        completion_id: str,
+        created: int,
+        model_name: str,
+        text: str,
+        logprobs_or_none: Optional[CompletionLogprobs],
+        include_usage: bool,
+        index: int,
+        finish_reason: Optional[str],
+        usage: Optional[Dict[str, Any]] = None,
+    ) -> CreateChatCompletionStreamResponse:
+        """
+        Create a chunk for the streaming API. When usage reporting is
+        requested, the chunk carries an additional "usage" field.
+        """
+        if include_usage:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    },
+                ],
+                "usage": usage,
+            }
+        else:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+            }
+        return token
+
     def _create_completion(
         self,
         prompt: Union[str, List[int]],
@@ -963,6 +1013,7 @@ def _create_completion(
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        stream_include_usage: bool = False,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1178,6 +1229,7 @@ def logit_bias_processor(
                     break
 
             if stream:
+                include_usage = stream_include_usage
                 remaining_tokens = completion_tokens[returned_tokens:]
                 remaining_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
                 remaining_length = len(remaining_text)
@@ -1242,22 +1294,23 @@ def logit_bias_processor(
                             "top_logprobs": [top_logprob],
                         }
                         returned_tokens += 1
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
-                                        "utf-8", errors="ignore"
-                                    ),
-                                    "index": 0,
-                                    "logprobs": logprobs_or_none,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        text = self.detokenize(
+                            [token],
+                            prev_tokens=prompt_tokens
+                            + completion_tokens[:returned_tokens],
+                        ).decode("utf-8", errors="ignore")
+                        yield self._create_chunk(
+                            completion_id=completion_id,
+                            created=created,
+                            model_name=model_name,
+                            text=text,
+                            finish_reason=None,
+                            index=0,
+                            logprobs_or_none=logprobs_or_none,
+                            include_usage=include_usage,
+                        )
                 else:
                     while len(remaining_tokens) > 0:
                         decode_success = False
@@ -1282,20 +1335,16 @@ def logit_bias_processor(
                             remaining_tokens = remaining_tokens[i:]
                             returned_tokens += i
 
-                            yield {
-                                "id": completion_id,
-                                "object": "text_completion",
-                                "created": created,
-                                "model": model_name,
-                                "choices": [
-                                    {
-                                        "text": ts,
-                                        "index": 0,
-                                        "logprobs": None,
-                                        "finish_reason": None,
-                                    }
-                                ],
-                            }
+                            yield self._create_chunk(
+                                index=0,
+                                finish_reason=None,
+                                completion_id=completion_id,
+                                created=created,
+                                model_name=model_name,
+                                text=ts,
+                                logprobs_or_none=None,
+                                include_usage=include_usage,
+                            )
 
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -1362,54 +1411,60 @@ def logit_bias_processor(
                         if token_end_position == end - 1:
                             break
                         returned_tokens += 1
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": last_text[
-                                        : len(last_text) - (token_end_position - end)
-                                    ].decode("utf-8", errors="ignore"),
-                                    "index": 0,
-                                    "logprobs": logprobs_or_none,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        text = last_text[
+                            : len(last_text) - (token_end_position - end)
+                        ].decode("utf-8", errors="ignore")
+
+                        yield self._create_chunk(
+                            completion_id=completion_id,
+                            created=created,
+                            model_name=model_name,
+                            text=text,
+                            logprobs_or_none=logprobs_or_none,
+                            include_usage=include_usage,
+                            index=0,
+                            finish_reason=None,
+                        )
                         break
                     returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": self.detokenize([token]).decode(
-                                    "utf-8", errors="ignore"
-                                ),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": "",
-                            "index": 0,
-                            "logprobs": None,
-                            "finish_reason": finish_reason,
-                        }
-                    ],
-                }
+                    text = self.detokenize([token]).decode("utf-8", errors="ignore")
+                    yield self._create_chunk(
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text=text,
+                        logprobs_or_none=logprobs_or_none,
+                        include_usage=include_usage,
+                        index=0,
+                        finish_reason=None,
+                    )
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text="",
+                    index=0,
+                    logprobs_or_none=None,
+                    include_usage=include_usage,
+                    usage=None,
+                    finish_reason=finish_reason,
+                )
+
+                if include_usage:
+                    yield self._create_chunk(
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text="",
+                        logprobs_or_none=None,
+                        include_usage=include_usage,
+                        index=0,
+                        finish_reason=None,
+                        usage={
+                            "prompt_tokens": len(prompt_tokens),
+                            "completion_tokens": returned_tokens,
+                            "total_tokens": len(prompt_tokens) + returned_tokens,
+                        },
+                    )
             if self.cache:
                 if self.verbose:
                     print("Llama._create_completion: cache save", file=sys.stderr)
@@ -1510,6 +1565,7 @@ def logit_bias_processor(
                 },
             }
 
+
     def create_completion(
         self,
         prompt: Union[str, List[int]],
@@ -1527,6 +1583,7 @@ def create_completion(
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        stream_include_usage: bool = False,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1590,6 +1647,7 @@ def create_completion(
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            stream_include_usage=stream_include_usage,
             seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1624,6 +1682,7 @@ def __call__(
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        stream_include_usage: bool = False,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1687,6 +1746,7 @@ def __call__(
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            stream_include_usage=stream_include_usage,
             seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1712,6 +1772,7 @@ def create_chat_completion(
         min_p: float = 0.05,
         typical_p: float = 1.0,
         stream: bool = False,
+        stream_include_usage: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
         response_format: Optional[ChatCompletionRequestResponseFormat] = None,
@@ -1783,6 +1844,7 @@ def create_chat_completion(
             logprobs=logprobs,
             top_logprobs=top_logprobs,
             stream=stream,
+            stream_include_usage=stream_include_usage,
             stop=stop,
             seed=seed,
             response_format=response_format,
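
For reference, a minimal sketch of how a caller might consume the new flag once this patch is applied. The model path and prompt below are placeholders; the parameter names are the ones added in the diff above. With stream_include_usage=True every chunk carries a "usage" key (None until the end), and a final extra chunk reports the token counts.

    from llama_cpp import Llama

    llm = Llama(model_path="./model.gguf")  # placeholder path

    for chunk in llm.create_completion(
        "Q: Name the planets in the solar system. A:",  # placeholder prompt
        max_tokens=32,
        stream=True,
        stream_include_usage=True,
    ):
        usage = chunk.get("usage")
        if usage is not None:
            # Final chunk: report the accounting added by this PR.
            print("\nprompt:", usage["prompt_tokens"],
                  "completion:", usage["completion_tokens"],
                  "total:", usage["total_tokens"])
        else:
            # Ordinary content chunk (usage is None until the stream ends).
            print(chunk["choices"][0]["text"], end="")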