
Commit a04db0e

This commit adds a stream_include_usage option so streamed completions can report token usage
1 parent 6c33190 commit a04db0e
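
For context, a minimal usage sketch of the option this commit introduces (not part of the commit; the model path and prompt are placeholders, and only the completion-level API is shown because its handling is fully visible in the llama.py diff below):

from llama_cpp import Llama

# Placeholder path -- point this at any local GGUF model.
llm = Llama(model_path="./models/model.gguf")

usage = None
for chunk in llm.create_completion(
    "Q: Name the planets in the solar system. A: ",
    max_tokens=32,
    stream=True,
    stream_include_usage=True,  # new parameter added by this commit
):
    print(chunk["choices"][0]["text"], end="", flush=True)
    # With stream_include_usage=True every chunk carries a "usage" key,
    # but it is only populated on the final extra chunk.
    if chunk.get("usage") is not None:
        usage = chunk["usage"]

print()
print(usage)  # {'prompt_tokens': ..., 'completion_tokens': ..., 'total_tokens': ...}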

File tree: 5 files changed, +436 −312 lines changed


llama_cpp/llama.py

+138 −76
@@ -946,6 +946,56 @@ def decode_batch(seq_sizes: List[int]):
         else:
             return output
 
+    def _create_chunk(
+        self,
+        completion_id: str,
+        created: int,
+        model_name: str,
+        text: str,
+        logprobs_or_none: Union[Optional[CompletionLogprobs], None],
+        include_usage: bool,
+        index: int,
+        finish_reason: Union[str, None],
+        usage: Union[Dict[str, Any], None] = None,
+    ) -> CreateChatCompletionStreamResponse:
+        """
+        Create a chunk for the streaming API. Depending on whether usage was
+        requested, the chunk does or does not carry an additional "usage" field.
+        """
+
+        if include_usage:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    },
+                ],
+                "usage": usage,
+            }
+        else:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+            }
+        return token
+
     def _create_completion(
         self,
         prompt: Union[str, List[int]],
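
To make the two shapes concrete, this is roughly what the helper above should return for the final chunk of a stream when include_usage is True (illustrative sketch only; every value below is invented):

# Illustrative only: field values are invented, not captured output.
final_chunk = {
    "id": "cmpl-xxxxxxxxxxxx",
    "object": "text_completion",
    "created": 1700000000,
    "model": "example-model",
    "choices": [
        {"text": "", "index": 0, "logprobs": None, "finish_reason": None},
    ],
    "usage": {
        "prompt_tokens": 12,
        "completion_tokens": 34,
        "total_tokens": 46,
    },
}
# When include_usage is False, the "usage" key is omitted entirely.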
@@ -963,6 +1013,7 @@ def _create_completion(
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        stream_include_usage: Optional[bool] = False,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1178,6 +1229,7 @@ def logit_bias_processor(
                     break
 
             if stream:
+                include_usage = stream_include_usage
                 remaining_tokens = completion_tokens[returned_tokens:]
                 remaining_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
                 remaining_length = len(remaining_text)
@@ -1242,22 +1294,23 @@ def logit_bias_processor(
                             "top_logprobs": [top_logprob],
                         }
                         returned_tokens += 1
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
-                                        "utf-8", errors="ignore"
-                                    ),
-                                    "index": 0,
-                                    "logprobs": logprobs_or_none,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        text = (
+                            self.detokenize(
+                                [token],
+                                prev_tokens=prompt_tokens
+                                + completion_tokens[:returned_tokens],
+                            ).decode("utf-8", errors="ignore")
+                        )
+                        yield self._create_chunk(
+                            completion_id=completion_id,
+                            created=created,
+                            model_name=model_name,
+                            text=text,
+                            finish_reason=None,
+                            index=0,
+                            logprobs_or_none=logprobs_or_none,
+                            include_usage=include_usage,
+                        )
                 else:
                     while len(remaining_tokens) > 0:
                         decode_success = False
@@ -1282,20 +1335,16 @@ def logit_bias_processor(
                         remaining_tokens = remaining_tokens[i:]
                         returned_tokens += i
 
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": ts,
-                                    "index": 0,
-                                    "logprobs": None,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        yield self._create_chunk(
+                            index=0,
+                            finish_reason=None,
+                            completion_id=completion_id,
+                            created=created,
+                            model_name=model_name,
+                            text=ts,
+                            logprobs_or_none=None,
+                            include_usage=include_usage,
+                        )
 
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -1362,54 +1411,60 @@ def logit_bias_processor(
                     if token_end_position == end - 1:
                         break
                     returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": last_text[
-                                    : len(last_text) - (token_end_position - end)
-                                ].decode("utf-8", errors="ignore"),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    text = last_text[
+                        : len(last_text) - (token_end_position - end)
+                    ].decode("utf-8", errors="ignore")
+
+                    yield self._create_chunk(
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text=text,
+                        logprobs_or_none=logprobs_or_none,
+                        include_usage=include_usage,
+                        index=0,
+                        finish_reason=None,
+                    )
                     break
                 returned_tokens += 1
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": self.detokenize([token]).decode(
-                                "utf-8", errors="ignore"
-                            ),
-                            "index": 0,
-                            "logprobs": logprobs_or_none,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
-            yield {
-                "id": completion_id,
-                "object": "text_completion",
-                "created": created,
-                "model": model_name,
-                "choices": [
-                    {
-                        "text": "",
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": finish_reason,
-                    }
-                ],
-            }
+                text = self.detokenize([token]).decode("utf-8", errors="ignore")
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text=text,
+                    logprobs_or_none=logprobs_or_none,
+                    include_usage=include_usage,
+                    index=0,
+                    finish_reason=None,
+                )
+            yield self._create_chunk(
+                completion_id=completion_id,
+                created=created,
+                model_name=model_name,
+                text="",
+                index=0,
+                logprobs_or_none=None,
+                include_usage=include_usage,
+                usage=None,
+                finish_reason=finish_reason)
+
+            if include_usage:
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text="",
+                    logprobs_or_none=None,
+                    include_usage=include_usage,
+                    index=0,
+                    finish_reason=None,
+                    usage={
+                        "prompt_tokens": len(prompt_tokens),
+                        "completion_tokens": returned_tokens,
+                        "total_tokens": len(prompt_tokens) + returned_tokens,
+                    },
+                )
             if self.cache:
                 if self.verbose:
                     print("Llama._create_completion: cache save", file=sys.stderr)
@@ -1510,6 +1565,7 @@ def logit_bias_processor(
             },
         }
 
+
     def create_completion(
         self,
         prompt: Union[str, List[int]],
@@ -1527,6 +1583,7 @@ def create_completion(
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        stream_include_usage: bool = False,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1590,6 +1647,7 @@ def create_completion(
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            stream_include_usage=stream_include_usage,
             seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1624,6 +1682,7 @@ def __call__(
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        stream_include_usage: Optional[bool] = False,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1687,6 +1746,7 @@ def __call__(
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            stream_include_usage=stream_include_usage,
             seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1712,6 +1772,7 @@ def create_chat_completion(
         min_p: float = 0.05,
         typical_p: float = 1.0,
         stream: bool = False,
+        stream_include_usage: Optional[bool] = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
         response_format: Optional[ChatCompletionRequestResponseFormat] = None,
@@ -1783,6 +1844,7 @@ def create_chat_completion(
             logprobs=logprobs,
             top_logprobs=top_logprobs,
             stream=stream,
+            stream_include_usage=stream_include_usage,
             stop=stop,
             seed=seed,
             response_format=response_format,
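
The flag is threaded through __call__, create_completion, and create_chat_completion alike; the chat-side chunk handling lives in the other changed files, which are not shown in this excerpt. A hedged sketch of the chat-level call, assuming the chat handler simply forwards the option and emits a final usage-bearing chunk in the same way:

# Sketch only: assumes the chat handler (changed in files not shown here)
# forwards stream_include_usage unchanged. Reuses the `llm` object from the
# sketch near the top of this page.
for chunk in llm.create_chat_completion(
    messages=[{"role": "user", "content": "Say hi in one word."}],
    stream=True,
    stream_include_usage=True,
):
    if chunk["choices"]:
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content") or "", end="", flush=True)
    if chunk.get("usage"):
        print("\nusage:", chunk["usage"])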
