Skip to content

utils: export ParseFloat and MustParseFloat wrapping internal utils #3371

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions helper/helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package helper

import "github.com/redis/go-redis/v9/internal/util"

func ParseFloat(s string) (float64, error) {
return util.ParseStringToFloat(s)
}

func MustParseFloat(s string) float64 {
return util.MustParseFloat(s)
}
30 changes: 30 additions & 0 deletions internal/util/convert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package util

import (
"fmt"
"math"
"strconv"
)

// ParseFloat parses a Redis RESP3 float reply into a Go float64,
// handling "inf", "-inf", "nan" per Redis conventions.
func ParseStringToFloat(s string) (float64, error) {
switch s {
case "inf":
return math.Inf(1), nil
case "-inf":
return math.Inf(-1), nil
case "nan", "-nan":
return math.NaN(), nil
}
return strconv.ParseFloat(s, 64)
}

// MustParseFloat is like ParseFloat but panics on parse errors.
func MustParseFloat(s string) float64 {
f, err := ParseStringToFloat(s)
if err != nil {
panic(fmt.Sprintf("redis: failed to parse float %q: %v", s, err))
}
return f
}
40 changes: 40 additions & 0 deletions internal/util/convert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package util

import (
"math"
"testing"
)

func TestParseStringToFloat(t *testing.T) {
tests := []struct {
in string
want float64
ok bool
}{
{"1.23", 1.23, true},
{"inf", math.Inf(1), true},
{"-inf", math.Inf(-1), true},
{"nan", math.NaN(), true},
{"oops", 0, false},
}

for _, tc := range tests {
got, err := ParseStringToFloat(tc.in)
if tc.ok {
if err != nil {
t.Fatalf("ParseFloat(%q) error: %v", tc.in, err)
}
if math.IsNaN(tc.want) {
if !math.IsNaN(got) {
t.Errorf("ParseFloat(%q) = %v; want NaN", tc.in, got)
}
} else if got != tc.want {
t.Errorf("ParseFloat(%q) = %v; want %v", tc.in, got, tc.want)
}
} else {
if err == nil {
t.Errorf("ParseFloat(%q) expected error, got nil", tc.in)
}
}
}
}
71 changes: 66 additions & 5 deletions search_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package redis_test

import (
"bytes"
"context"
"encoding/binary"
"fmt"
"strconv"
"math"
"strings"
"time"

. "github.com/bsm/ginkgo/v2"
. "github.com/bsm/gomega"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9/helper"
)

func WaitForIndexing(c *redis.Client, index string) {
Expand All @@ -27,6 +30,14 @@ func WaitForIndexing(c *redis.Client, index string) {
}
}

func encodeFloat32Vector(vec []float32) []byte {
buf := new(bytes.Buffer)
for _, v := range vec {
binary.Write(buf, binary.LittleEndian, v)
}
return buf.Bytes()
}

var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
ctx := context.TODO()
var client *redis.Client
Expand Down Expand Up @@ -693,9 +704,9 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
Expect(err).NotTo(HaveOccurred())
Expect(res).ToNot(BeNil())
Expect(len(res.Rows)).To(BeEquivalentTo(2))
score1, err := strconv.ParseFloat(fmt.Sprintf("%s", res.Rows[0].Fields["__score"]), 64)
score1, err := helper.ParseFloat(fmt.Sprintf("%s", res.Rows[0].Fields["__score"]))
Expect(err).NotTo(HaveOccurred())
score2, err := strconv.ParseFloat(fmt.Sprintf("%s", res.Rows[1].Fields["__score"]), 64)
score2, err := helper.ParseFloat(fmt.Sprintf("%s", res.Rows[1].Fields["__score"]))
Expect(err).NotTo(HaveOccurred())
Expect(score1).To(BeNumerically(">", score2))

Expand All @@ -712,9 +723,9 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
Expect(err).NotTo(HaveOccurred())
Expect(resDM).ToNot(BeNil())
Expect(len(resDM.Rows)).To(BeEquivalentTo(2))
score1DM, err := strconv.ParseFloat(fmt.Sprintf("%s", resDM.Rows[0].Fields["__score"]), 64)
score1DM, err := helper.ParseFloat(fmt.Sprintf("%s", resDM.Rows[0].Fields["__score"]))
Expect(err).NotTo(HaveOccurred())
score2DM, err := strconv.ParseFloat(fmt.Sprintf("%s", resDM.Rows[1].Fields["__score"]), 64)
score2DM, err := helper.ParseFloat(fmt.Sprintf("%s", resDM.Rows[1].Fields["__score"]))
Expect(err).NotTo(HaveOccurred())
Expect(score1DM).To(BeNumerically(">", score2DM))

Expand Down Expand Up @@ -1684,6 +1695,56 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
Expect(resUint8.Docs[0].ID).To(BeEquivalentTo("doc1"))
})

It("should return special float scores in FT.SEARCH vecsim", Label("search", "ftsearch", "vecsim"), func() {
SkipBeforeRedisVersion(7.4, "doesn't work with older redis stack images")

vecField := &redis.FTFlatOptions{
Type: "FLOAT32",
Dim: 2,
DistanceMetric: "IP",
}
_, err := client.FTCreate(ctx, "idx_vec",
&redis.FTCreateOptions{OnHash: true, Prefix: []interface{}{"doc:"}},
&redis.FieldSchema{FieldName: "vector", FieldType: redis.SearchFieldTypeVector, VectorArgs: &redis.FTVectorArgs{FlatOptions: vecField}}).Result()
Expect(err).NotTo(HaveOccurred())
WaitForIndexing(client, "idx_vec")

bigPos := []float32{1e38, 1e38}
bigNeg := []float32{-1e38, -1e38}
nanVec := []float32{float32(math.NaN()), 0}
negNanVec := []float32{float32(math.Copysign(math.NaN(), -1)), 0}

client.HSet(ctx, "doc:1", "vector", encodeFloat32Vector(bigPos))
client.HSet(ctx, "doc:2", "vector", encodeFloat32Vector(bigNeg))
client.HSet(ctx, "doc:3", "vector", encodeFloat32Vector(nanVec))
client.HSet(ctx, "doc:4", "vector", encodeFloat32Vector(negNanVec))

searchOptions := &redis.FTSearchOptions{WithScores: true, Params: map[string]interface{}{"vec": encodeFloat32Vector(bigPos)}}
res, err := client.FTSearchWithArgs(ctx, "idx_vec", "*=>[KNN 4 @vector $vec]", searchOptions).Result()
Expect(err).NotTo(HaveOccurred())
Expect(res.Total).To(BeEquivalentTo(4))

var scores []float64
for _, row := range res.Docs {
raw := fmt.Sprintf("%v", row.Fields["__vector_score"])
f, err := helper.ParseFloat(raw)
Expect(err).NotTo(HaveOccurred())
scores = append(scores, f)
}

Expect(scores).To(ContainElement(BeNumerically("==", math.Inf(1))))
Expect(scores).To(ContainElement(BeNumerically("==", math.Inf(-1))))

// For NaN values, use a custom check since NaN != NaN in floating point math
nanCount := 0
for _, score := range scores {
if math.IsNaN(score) {
nanCount++
}
}
Expect(nanCount).To(Equal(2))
})

It("should fail when using a non-zero offset with a zero limit", Label("search", "ftsearch"), func() {
SkipBeforeRedisVersion(7.9, "requires Redis 8.x")
val, err := client.FTCreate(ctx, "testIdx", &redis.FTCreateOptions{}, &redis.FieldSchema{
Expand Down
Loading