Skip to content

Commit 42c3284

Browse files
authored
utils: export ParseFloat and MustParseFloat wrapping internal utils (#3371)
* utils: expose ParseFloat via new public utils package * add tests for special float values in vector search
1 parent f174acb commit 42c3284

File tree

4 files changed

+147
-5
lines changed

4 files changed

+147
-5
lines changed

helper/helper.go

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package helper
2+
3+
import "github.com/redis/go-redis/v9/internal/util"
4+
5+
func ParseFloat(s string) (float64, error) {
6+
return util.ParseStringToFloat(s)
7+
}
8+
9+
func MustParseFloat(s string) float64 {
10+
return util.MustParseFloat(s)
11+
}

internal/util/convert.go

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package util
2+
3+
import (
4+
"fmt"
5+
"math"
6+
"strconv"
7+
)
8+
9+
// ParseFloat parses a Redis RESP3 float reply into a Go float64,
10+
// handling "inf", "-inf", "nan" per Redis conventions.
11+
func ParseStringToFloat(s string) (float64, error) {
12+
switch s {
13+
case "inf":
14+
return math.Inf(1), nil
15+
case "-inf":
16+
return math.Inf(-1), nil
17+
case "nan", "-nan":
18+
return math.NaN(), nil
19+
}
20+
return strconv.ParseFloat(s, 64)
21+
}
22+
23+
// MustParseFloat is like ParseFloat but panics on parse errors.
24+
func MustParseFloat(s string) float64 {
25+
f, err := ParseStringToFloat(s)
26+
if err != nil {
27+
panic(fmt.Sprintf("redis: failed to parse float %q: %v", s, err))
28+
}
29+
return f
30+
}

internal/util/convert_test.go

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package util
2+
3+
import (
4+
"math"
5+
"testing"
6+
)
7+
8+
func TestParseStringToFloat(t *testing.T) {
9+
tests := []struct {
10+
in string
11+
want float64
12+
ok bool
13+
}{
14+
{"1.23", 1.23, true},
15+
{"inf", math.Inf(1), true},
16+
{"-inf", math.Inf(-1), true},
17+
{"nan", math.NaN(), true},
18+
{"oops", 0, false},
19+
}
20+
21+
for _, tc := range tests {
22+
got, err := ParseStringToFloat(tc.in)
23+
if tc.ok {
24+
if err != nil {
25+
t.Fatalf("ParseFloat(%q) error: %v", tc.in, err)
26+
}
27+
if math.IsNaN(tc.want) {
28+
if !math.IsNaN(got) {
29+
t.Errorf("ParseFloat(%q) = %v; want NaN", tc.in, got)
30+
}
31+
} else if got != tc.want {
32+
t.Errorf("ParseFloat(%q) = %v; want %v", tc.in, got, tc.want)
33+
}
34+
} else {
35+
if err == nil {
36+
t.Errorf("ParseFloat(%q) expected error, got nil", tc.in)
37+
}
38+
}
39+
}
40+
}

search_test.go

+66-5
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
package redis_test
22

33
import (
4+
"bytes"
45
"context"
6+
"encoding/binary"
57
"fmt"
6-
"strconv"
8+
"math"
79
"strings"
810
"time"
911

1012
. "github.com/bsm/ginkgo/v2"
1113
. "github.com/bsm/gomega"
1214
"github.com/redis/go-redis/v9"
15+
"github.com/redis/go-redis/v9/helper"
1316
)
1417

1518
func WaitForIndexing(c *redis.Client, index string) {
@@ -27,6 +30,14 @@ func WaitForIndexing(c *redis.Client, index string) {
2730
}
2831
}
2932

33+
func encodeFloat32Vector(vec []float32) []byte {
34+
buf := new(bytes.Buffer)
35+
for _, v := range vec {
36+
binary.Write(buf, binary.LittleEndian, v)
37+
}
38+
return buf.Bytes()
39+
}
40+
3041
var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
3142
ctx := context.TODO()
3243
var client *redis.Client
@@ -693,9 +704,9 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
693704
Expect(err).NotTo(HaveOccurred())
694705
Expect(res).ToNot(BeNil())
695706
Expect(len(res.Rows)).To(BeEquivalentTo(2))
696-
score1, err := strconv.ParseFloat(fmt.Sprintf("%s", res.Rows[0].Fields["__score"]), 64)
707+
score1, err := helper.ParseFloat(fmt.Sprintf("%s", res.Rows[0].Fields["__score"]))
697708
Expect(err).NotTo(HaveOccurred())
698-
score2, err := strconv.ParseFloat(fmt.Sprintf("%s", res.Rows[1].Fields["__score"]), 64)
709+
score2, err := helper.ParseFloat(fmt.Sprintf("%s", res.Rows[1].Fields["__score"]))
699710
Expect(err).NotTo(HaveOccurred())
700711
Expect(score1).To(BeNumerically(">", score2))
701712

@@ -712,9 +723,9 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
712723
Expect(err).NotTo(HaveOccurred())
713724
Expect(resDM).ToNot(BeNil())
714725
Expect(len(resDM.Rows)).To(BeEquivalentTo(2))
715-
score1DM, err := strconv.ParseFloat(fmt.Sprintf("%s", resDM.Rows[0].Fields["__score"]), 64)
726+
score1DM, err := helper.ParseFloat(fmt.Sprintf("%s", resDM.Rows[0].Fields["__score"]))
716727
Expect(err).NotTo(HaveOccurred())
717-
score2DM, err := strconv.ParseFloat(fmt.Sprintf("%s", resDM.Rows[1].Fields["__score"]), 64)
728+
score2DM, err := helper.ParseFloat(fmt.Sprintf("%s", resDM.Rows[1].Fields["__score"]))
718729
Expect(err).NotTo(HaveOccurred())
719730
Expect(score1DM).To(BeNumerically(">", score2DM))
720731

@@ -1684,6 +1695,56 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
16841695
Expect(resUint8.Docs[0].ID).To(BeEquivalentTo("doc1"))
16851696
})
16861697

1698+
It("should return special float scores in FT.SEARCH vecsim", Label("search", "ftsearch", "vecsim"), func() {
1699+
SkipBeforeRedisVersion(7.4, "doesn't work with older redis stack images")
1700+
1701+
vecField := &redis.FTFlatOptions{
1702+
Type: "FLOAT32",
1703+
Dim: 2,
1704+
DistanceMetric: "IP",
1705+
}
1706+
_, err := client.FTCreate(ctx, "idx_vec",
1707+
&redis.FTCreateOptions{OnHash: true, Prefix: []interface{}{"doc:"}},
1708+
&redis.FieldSchema{FieldName: "vector", FieldType: redis.SearchFieldTypeVector, VectorArgs: &redis.FTVectorArgs{FlatOptions: vecField}}).Result()
1709+
Expect(err).NotTo(HaveOccurred())
1710+
WaitForIndexing(client, "idx_vec")
1711+
1712+
bigPos := []float32{1e38, 1e38}
1713+
bigNeg := []float32{-1e38, -1e38}
1714+
nanVec := []float32{float32(math.NaN()), 0}
1715+
negNanVec := []float32{float32(math.Copysign(math.NaN(), -1)), 0}
1716+
1717+
client.HSet(ctx, "doc:1", "vector", encodeFloat32Vector(bigPos))
1718+
client.HSet(ctx, "doc:2", "vector", encodeFloat32Vector(bigNeg))
1719+
client.HSet(ctx, "doc:3", "vector", encodeFloat32Vector(nanVec))
1720+
client.HSet(ctx, "doc:4", "vector", encodeFloat32Vector(negNanVec))
1721+
1722+
searchOptions := &redis.FTSearchOptions{WithScores: true, Params: map[string]interface{}{"vec": encodeFloat32Vector(bigPos)}}
1723+
res, err := client.FTSearchWithArgs(ctx, "idx_vec", "*=>[KNN 4 @vector $vec]", searchOptions).Result()
1724+
Expect(err).NotTo(HaveOccurred())
1725+
Expect(res.Total).To(BeEquivalentTo(4))
1726+
1727+
var scores []float64
1728+
for _, row := range res.Docs {
1729+
raw := fmt.Sprintf("%v", row.Fields["__vector_score"])
1730+
f, err := helper.ParseFloat(raw)
1731+
Expect(err).NotTo(HaveOccurred())
1732+
scores = append(scores, f)
1733+
}
1734+
1735+
Expect(scores).To(ContainElement(BeNumerically("==", math.Inf(1))))
1736+
Expect(scores).To(ContainElement(BeNumerically("==", math.Inf(-1))))
1737+
1738+
// For NaN values, use a custom check since NaN != NaN in floating point math
1739+
nanCount := 0
1740+
for _, score := range scores {
1741+
if math.IsNaN(score) {
1742+
nanCount++
1743+
}
1744+
}
1745+
Expect(nanCount).To(Equal(2))
1746+
})
1747+
16871748
It("should fail when using a non-zero offset with a zero limit", Label("search", "ftsearch"), func() {
16881749
SkipBeforeRedisVersion(7.9, "requires Redis 8.x")
16891750
val, err := client.FTCreate(ctx, "testIdx", &redis.FTCreateOptions{}, &redis.FieldSchema{

0 commit comments

Comments
 (0)