
Ensemble and tensorrt_llm_bls have different results when using accumulate_tokens #520

Closed
@activezhao

Description


System Info

CPU x86_64

GPU NVIDIA L20

TensorRT-LLM branch: v0.8.0

CUDA: 12.3 (NVIDIA-SMI / driver version 535.154.05)

Who can help?

@kaiyux @byshiue @schetlur-nv

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

When I use accumulate_tokens, the same request returns different results from the ensemble and tensorrt_llm_bls models.

Expected behavior

The results should be the same.

Actual behavior

With the same prompt and parameters, the ensemble and tensorrt_llm_bls APIs return different results.
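To rerun the comparison in one step, here is a minimal Python sketch (standard library only) that posts the same JSON payload to both models' `generate_stream` endpoints and checks whether `text_output` matches. It assumes the server address `localhost:8820` used in the curl commands, and that with `"stream": false` the reply arrives as a single `data: {...}` line, as in the results shown in this issue. The names `parse_sse_body`, `generate`, `compare_models`, and `PAYLOAD_TEMPLATE` are hypothetical, and `text_input` is a placeholder for the full prompt.

```python
import json
import urllib.request

# Assumption: the Triton server address used in the curl commands of this issue.
BASE_URL = "http://localhost:8820"

# Sampling parameters copied from the curl commands; text_input is a placeholder
# (hypothetical) to be replaced with the full prompt from the request.
PAYLOAD_TEMPLATE = {
    "text_input": "<full prompt from the curl command>",
    "max_tokens": 50,
    "bad_words": "",
    "stop_words": "",
    "stream": False,
    "temperature": 0.2,
    "top_p": 0.95,
    "return_log_probs": True,
    "generation_logits": True,
}


def parse_sse_body(body: str) -> dict:
    """Return the JSON object from the first `data: {...}` line of a response."""
    for line in body.splitlines():
        line = line.strip()
        if line.startswith("data:"):
            return json.loads(line[len("data:"):].strip())
    raise ValueError("no 'data:' line found in response")


def generate(model: str, payload: dict, base_url: str = BASE_URL) -> dict:
    """POST the payload to /v2/models/<model>/generate_stream and parse the reply."""
    req = urllib.request.Request(
        f"{base_url}/v2/models/{model}/generate_stream",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return parse_sse_body(resp.read().decode("utf-8"))


def compare_models(payload: dict, models=("ensemble", "tensorrt_llm_bls")) -> bool:
    """Send the identical payload to each model; True only if all text_output values match."""
    outputs = {m: generate(m, payload)["text_output"] for m in models}
    for m, text in outputs.items():
        print(f"--- {m} ---\n{text}")
    return len(set(outputs.values())) == 1
```

Usage, with a running server: `compare_models(dict(PAYLOAD_TEMPLATE, text_input=prompt))`, where `prompt` is the full `text_input` string from the curl commands.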

curl -X POST localhost:8820/v2/models/tensorrt_llm_bls/generate_stream

curl -X POST localhost:8820/v2/models/tensorrt_llm_bls/generate_stream -d '{"text_input": "\u003creponame\u003ecommon\n\u003cneighbor\u003e\u003cfilename\u003evalue\u003ccodeblock\u003e// Compare this snippet from waitpush/DrugRemindPush.go:...\u003cneighbor\u003e\u003cfilename\u003ekey\u003ccodeblock\u003eDrugRemindPush.go\u003cfilename\u003edosage_form.go\n\u003c|fim▁begin|\u003e\u003creponame\u003eprogramming-language-demo\n\u003cneighbor\u003e\u003cfilename\u003eprime-number.go\u003ccodeblock\u003e// }\n// func isPrime(n int) bool {\n//     if n \u003c 2 {\n//         return false\n//     } else {\n//         for i := 2; i \u003c= n/2; i++ {\n//             if n%i == 0 {\n//                 return false\n//             }\n//         }\n//     }\n//     return true\n// }\n// Functions from import file go/prime-number.go can be referenced:\n// func exitWithError()\n// func main()\n// func isPrime(n int) bool\n// Compare this snippet from go/prime-number.go:\n// package main\n// \n// import (\n//     \"fmt\"\n//     \"os\"\n//     \"strconv\"\n// )\n// \n// func isPrime(n int) bool {\n//     if n \u003c 2 {\n//         return false\n//     } else {\n//         for i := 2; i \u003c= n/2; i++ {\n//             if n%i == 0 {\n//                 return false\n//             }\n//         }\n//     }\n//     return true\n// }\n// \n// func exitWithError() {\n//     fmt.Println(\"Usage: please input a non-negative integer\")\n//     os.Exit(1)\n// }\n// \n// func main() {\n//     if len(os.Args) != 2 {\n//         exitWithError()\n//     }\n// \n//     n, err := strconv.Atoi(os.Args[1])\n//     if err != nil || n \u003c 0 {\n//         exitWithError()\n//     }\n// \n//     if isPrime(n) {\n//         fmt.Println(\"Prime\")\n//     } else {\n//         fmt.Println(\"Composite\")\n//     }\n// }\u003cneighbor\u003e\u003cfilename\u003eprime-number.go\u003ccodeblock\u003e// Functions from import file go/prime-number.go can be referenced:\n// func exitWithError() {\n//     
fmt.Println(\"Usage: please input a non-negative integer\")\n//     os.Exit(1)\n// }\n// func main() {\n//     if len(os.Args) != 2 {\n//         exitWithError()\n//     }\n// \n//     n, err := strconv.Atoi(os.Args[1])\n//     if err != nil || n \u003c 0 {\n//         exitWithError()\n//     }\n// \n//     if isPrime(n) {\n//         fmt.Println(\"Prime\")\n//     } else {\n//         fmt.Println(\"Composite\")\n//     }\n// }\n// func isPrime(n int) bool {\n//     if n \u003c 2 {\n//         return false\n//     } else {\n//         for i := 2; i \u003c= n/2; i++ {\n//             if n%i == 0 {\n//                 return false\n//             }\n//         }\n//     }\n//     return true\n// }\n// Functions from import file go/prime-number.go can be referenced:\n// func exitWithError()\n// func main()\n// func isPrime(n int) bool\n// Compare this snippet from go/prime-number.go:\n// package main\n// \n// import (\n//     \"fmt\"\n//     \"os\"\n//     \"strconv\"\n// )\n// \n// func isPrime(n int) bool {\n//     if n \u003c 2 {\n//         return false\n//     } else {\n//         for i := 2; i \u003c= n/2; i++ {\n//             if n%i == 0 {\n//                 return false\n//             }\n//         }\n//     }\n//     return true\n// }\n// \n// func exitWithError() {\u003cneighbor\u003e\u003cfilename\u003elongest-word.go\u003ccodeblock\u003e// Variables from import file go/longest-word.go can be referenced:\n// errorMessage = \"Usage: please provide a string\"\n// Functions from import file go/longest-word.go can be referenced:\n// func longestWordLength(str string) int {\n//     words := strings.FieldsFunc(str, isLimitedWhitespace)\n//     return longestStringLength(words)\n// }\n// func isLimitedWhitespace(r rune) bool {\n//     return strings.ContainsRune(\" \\t\\n\\r\", r)\n// }\n// func longestStringLength(strs []string) (longest int) {\n//     for _, str := range strs {\n//         if len(str) \u003e longest {\n//             longest = len(str)\n//    
     }\n//     }\n//     return\n// }\n// Functions from import file go/longest-word.go can be referenced:\n// func longestWordLength(str string) int\n// func isLimitedWhitespace(r rune) bool\n// func longestStringLength(strs []string) (longest int)\u003cneighbor\u003e\u003cfilename\u003efactorial.go\u003ccodeblock\u003e// Functions from import file go/factorial.go can be referenced:\n// func exitWithError(msg string) {\n//     fmt.Println(msg)\n//     os.Exit(1)\n// }\n// func factorial(n uint64) uint64 {\n//     if n \u003c= 0 {\n//         return 1\n//     }\n//     return n * factorial(n-1)\n// }\n// Functions from import file go/factorial.go can be referenced:\n// func exitWithError(msg string)\n// func factorial(n uint64) uint64\u003cfilename\u003elongest-common-subsequence.go\n\u003ccodecontent\u003epackage main\nimport (\n    \"encoding/json\"\n    \"fmt\"\n    \"os\"\n    \"regexp\"\n    \"strconv\"\n    \"strings\"\n)\n//exitWithError\n\u003c|fim▁end|\u003e}\n\u003c|fim▁hole|\u003e", "max_tokens": 50, "bad_words": "", "stop_words": "", "stream": false, "temperature": 0.2, "top_p": 0.95, "return_log_probs": true, "generation_logits": true}'

The result is:

data: {"context_logits":0.0,"cum_log_probs":-77.98719787597656,"generation_logits":0.0,"model_name":"tensorrt_llm_bls","model_version":"1","output_log_probs":[-1.3984918594360352,-3.991654872894287,-2.127605676651001,-0.18318799138069154,-0.15039844810962678,-0.3713747262954712,-2.1666009426116945,-0.03320259973406792,-0.6704073548316956,-3.395005941390991,-6.215298652648926,-3.6144485473632814,-3.8179116249084474,-1.1550722122192383,-1.0524828433990479,-0.32207995653152468,-0.4670903980731964,-5.648696422576904,-3.6973865032196047,-3.8024346828460695,-0.13288161158561707,-3.7232208251953127,-2.065372943878174,-0.026736034080386163,-0.30800527334213259,-0.15478214621543885,-3.5880002975463869,-2.564371109008789,-1.118330717086792,-0.008484973572194577,-1.2587940692901612,-0.5912411212921143,-2.966789484024048,-2.6259653568267824,-0.009489176794886589,-0.018396474421024324,-0.12405481934547425,-2.876150131225586,-0.15892530977725984,-3.3690268993377687,-3.163250684738159,-1.4551129341125489,-0.021045353263616563,-0.0005316358874551952,-0.05893709510564804,-1.1418265104293824,-0.00010598267544992268,-0.03211848437786102,-0.10972829163074494,-0.03469150885939598],"text_output":"//findLCS\n//main\n//func removeWhiteSpace\n//func processCommandLineArgs\n//func main() {\n//    var (\n//        lcs       = findLCS(os.Args[1"}

The decoded text_output is:

//findLCS
//main
//func removeWhiteSpace
//func processCommandLineArgs
//func main() {
//    var (
//        lcs       = findLCS(os.Args[1

curl -X POST localhost:8820/v2/models/ensemble/generate_stream

curl -X POST localhost:8820/v2/models/ensemble/generate_stream -d '{"text_input": "\u003creponame\u003ecommon\n\u003cneighbor\u003e\u003cfilename\u003evalue\u003ccodeblock\u003e// Compare this snippet from waitpush/DrugRemindPush.go:...\u003cneighbor\u003e\u003cfilename\u003ekey\u003ccodeblock\u003eDrugRemindPush.go\u003cfilename\u003edosage_form.go\n\u003c|fim▁begin|\u003e\u003creponame\u003eprogramming-language-demo\n\u003cneighbor\u003e\u003cfilename\u003eprime-number.go\u003ccodeblock\u003e// }\n// func isPrime(n int) bool {\n//     if n \u003c 2 {\n//         return false\n//     } else {\n//         for i := 2; i \u003c= n/2; i++ {\n//             if n%i == 0 {\n//                 return false\n//             }\n//         }\n//     }\n//     return true\n// }\n// Functions from import file go/prime-number.go can be referenced:\n// func exitWithError()\n// func main()\n// func isPrime(n int) bool\n// Compare this snippet from go/prime-number.go:\n// package main\n// \n// import (\n//     \"fmt\"\n//     \"os\"\n//     \"strconv\"\n// )\n// \n// func isPrime(n int) bool {\n//     if n \u003c 2 {\n//         return false\n//     } else {\n//         for i := 2; i \u003c= n/2; i++ {\n//             if n%i == 0 {\n//                 return false\n//             }\n//         }\n//     }\n//     return true\n// }\n// \n// func exitWithError() {\n//     fmt.Println(\"Usage: please input a non-negative integer\")\n//     os.Exit(1)\n// }\n// \n// func main() {\n//     if len(os.Args) != 2 {\n//         exitWithError()\n//     }\n// \n//     n, err := strconv.Atoi(os.Args[1])\n//     if err != nil || n \u003c 0 {\n//         exitWithError()\n//     }\n// \n//     if isPrime(n) {\n//         fmt.Println(\"Prime\")\n//     } else {\n//         fmt.Println(\"Composite\")\n//     }\n// }\u003cneighbor\u003e\u003cfilename\u003eprime-number.go\u003ccodeblock\u003e// Functions from import file go/prime-number.go can be referenced:\n// func exitWithError() {\n//     
fmt.Println(\"Usage: please input a non-negative integer\")\n//     os.Exit(1)\n// }\n// func main() {\n//     if len(os.Args) != 2 {\n//         exitWithError()\n//     }\n// \n//     n, err := strconv.Atoi(os.Args[1])\n//     if err != nil || n \u003c 0 {\n//         exitWithError()\n//     }\n// \n//     if isPrime(n) {\n//         fmt.Println(\"Prime\")\n//     } else {\n//         fmt.Println(\"Composite\")\n//     }\n// }\n// func isPrime(n int) bool {\n//     if n \u003c 2 {\n//         return false\n//     } else {\n//         for i := 2; i \u003c= n/2; i++ {\n//             if n%i == 0 {\n//                 return false\n//             }\n//         }\n//     }\n//     return true\n// }\n// Functions from import file go/prime-number.go can be referenced:\n// func exitWithError()\n// func main()\n// func isPrime(n int) bool\n// Compare this snippet from go/prime-number.go:\n// package main\n// \n// import (\n//     \"fmt\"\n//     \"os\"\n//     \"strconv\"\n// )\n// \n// func isPrime(n int) bool {\n//     if n \u003c 2 {\n//         return false\n//     } else {\n//         for i := 2; i \u003c= n/2; i++ {\n//             if n%i == 0 {\n//                 return false\n//             }\n//         }\n//     }\n//     return true\n// }\n// \n// func exitWithError() {\u003cneighbor\u003e\u003cfilename\u003elongest-word.go\u003ccodeblock\u003e// Variables from import file go/longest-word.go can be referenced:\n// errorMessage = \"Usage: please provide a string\"\n// Functions from import file go/longest-word.go can be referenced:\n// func longestWordLength(str string) int {\n//     words := strings.FieldsFunc(str, isLimitedWhitespace)\n//     return longestStringLength(words)\n// }\n// func isLimitedWhitespace(r rune) bool {\n//     return strings.ContainsRune(\" \\t\\n\\r\", r)\n// }\n// func longestStringLength(strs []string) (longest int) {\n//     for _, str := range strs {\n//         if len(str) \u003e longest {\n//             longest = len(str)\n//    
     }\n//     }\n//     return\n// }\n// Functions from import file go/longest-word.go can be referenced:\n// func longestWordLength(str string) int\n// func isLimitedWhitespace(r rune) bool\n// func longestStringLength(strs []string) (longest int)\u003cneighbor\u003e\u003cfilename\u003efactorial.go\u003ccodeblock\u003e// Functions from import file go/factorial.go can be referenced:\n// func exitWithError(msg string) {\n//     fmt.Println(msg)\n//     os.Exit(1)\n// }\n// func factorial(n uint64) uint64 {\n//     if n \u003c= 0 {\n//         return 1\n//     }\n//     return n * factorial(n-1)\n// }\n// Functions from import file go/factorial.go can be referenced:\n// func exitWithError(msg string)\n// func factorial(n uint64) uint64\u003cfilename\u003elongest-common-subsequence.go\n\u003ccodecontent\u003epackage main\nimport (\n    \"encoding/json\"\n    \"fmt\"\n    \"os\"\n    \"regexp\"\n    \"strconv\"\n    \"strings\"\n)\n//exitWithError\n\u003c|fim▁end|\u003e}\n\u003c|fim▁hole|\u003e", "max_tokens": 50, "bad_words": "", "stop_words": "", "stream": false, "temperature": 0.2, "top_p": 0.95, "return_log_probs": true, "generation_logits": true}'

The result is:

data: {"context_logits":0.0,"cum_log_probs":-77.98719787597656,"generation_logits":0.0,"model_name":"tensorrt_llm_bls","model_version":"1","output_log_probs":[-1.3984918594360352,-3.991654872894287,-2.127605676651001,-0.18318799138069154,-0.15039844810962678,-0.3713747262954712,-2.1666009426116945,-0.03320259973406792,-0.6704073548316956,-3.395005941390991,-6.215298652648926,-3.6144485473632814,-3.8179116249084474,-1.1550722122192383,-1.0524828433990479,-0.32207995653152468,-0.4670903980731964,-5.648696422576904,-3.6973865032196047,-3.8024346828460695,-0.13288161158561707,-3.7232208251953127,-2.065372943878174,-0.026736034080386163,-0.30800527334213259,-0.15478214621543885,-3.5880002975463869,-2.564371109008789,-1.118330717086792,-0.008484973572194577,-1.2587940692901612,-0.5912411212921143,-2.966789484024048,-2.6259653568267824,-0.009489176794886589,-0.018396474421024324,-0.12405481934547425,-2.876150131225586,-0.15892530977725984,-3.3690268993377687,-3.163250684738159,-1.4551129341125489,-0.021045353263616563,-0.0005316358874551952,-0.05893709510564804,-1.1418265104293824,-0.00010598267544992268,-0.03211848437786102,-0.10972829163074494,-0.03469150885939598],"text_output":"//findLCS\n//main\n//func removeWhiteSpace\n//func processCommandLineArgs\n//func main() {\n//    var (\n//        lcs       = findLCS(os.Args[1"}

The relevant part of the decoded text_output is:

func exitWithError(msg string) {
    fmt.Println(msg)
    os.Exit(1)
}
//longestCommonSubsequence
func longestCommonSubsequence(a, b string) string {

The ensemble result is the expected one.

I also printed the output_ids, and they are different too.
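The output_ids themselves are not included in this issue. To narrow down where the two runs start to differ (for example, whether the very first generated token already differs), a small helper like the following could compare the two `output_ids` lists. This is a hypothetical sketch: the name `first_divergence` is illustrative, and it assumes both inputs are plain lists of token ids.

```python
from typing import Optional, Sequence


def first_divergence(a: Sequence[int], b: Sequence[int]) -> Optional[int]:
    """Return the first index at which two token-id sequences differ.

    Returns None when the sequences are identical. If one sequence is a
    prefix of the other, the length of the shorter one is returned.
    """
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    if len(a) != len(b):
        return min(len(a), len(b))
    return None
```

Usage: `first_divergence(ensemble_ids, bls_ids)`, where both (hypothetical) variables hold the `output_ids` printed from each model.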

Additional notes

I'm confused about why this happens; I would expect the results to be the same.

Is there a way to solve this problem?

Thanks.

Metadata


Assignees: No one assigned

Labels: bug (Something isn't working)

Type: No type

Projects: No projects

Milestone: No milestone

Relationships: None yet

Development: No branches or pull requests
