Skip to content

Commit 6f40716

Browse files
committed
Prevent panic on self-referencing structs/map/slices during Marshal
1 parent 71ac162 commit 6f40716

10 files changed

+64
-13
lines changed

reflect.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ func _createEncoderOfType(ctx *ctx, typ reflect2.Type) ValEncoder {
270270
kind := typ.Kind()
271271
switch kind {
272272
case reflect.Interface:
273-
return &dynamicEncoder{typ}
273+
return &dynamicEncoder{valType: typ, seen: make(map[unsafe.Pointer]bool, 1)}
274274
case reflect.Struct:
275275
return encoderOfStruct(ctx, typ)
276276
case reflect.Array:

reflect_dynamic.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
11
package jsoniter
22

33
import (
4-
"github.com/modern-go/reflect2"
54
"reflect"
65
"unsafe"
6+
7+
"github.com/modern-go/reflect2"
78
)
89

910
type dynamicEncoder struct {
1011
valType reflect2.Type
12+
seen map[unsafe.Pointer]bool
1113
}
1214

1315
func (encoder *dynamicEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
16+
if encoder.seen[ptr] {
17+
stream.Error = ErrEncounterCycle
18+
return
19+
}
20+
encoder.seen[ptr] = true
21+
defer delete(encoder.seen, ptr)
22+
1423
obj := encoder.valType.UnsafeIndirect(ptr)
1524
stream.WriteVal(obj)
1625
}

reflect_extension.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@ package jsoniter
22

33
import (
44
"fmt"
5-
"github.com/modern-go/reflect2"
65
"reflect"
76
"sort"
87
"strings"
98
"unicode"
109
"unsafe"
10+
11+
"github.com/modern-go/reflect2"
1112
)
1213

1314
var typeDecoders = map[string]ValDecoder{}
@@ -325,7 +326,7 @@ func _getTypeEncoderFromExtension(ctx *ctx, typ reflect2.Type) ValEncoder {
325326
typePtr := typ.(*reflect2.UnsafePtrType)
326327
encoder := typeEncoders[typePtr.Elem().String()]
327328
if encoder != nil {
328-
return &OptionalEncoder{encoder}
329+
return &OptionalEncoder{ValueEncoder: encoder, seen: make(map[unsafe.Pointer]bool, 1)}
329330
}
330331
}
331332
return nil

reflect_optional.go

+14-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
package jsoniter
22

33
import (
4-
"github.com/modern-go/reflect2"
54
"unsafe"
5+
6+
"github.com/modern-go/reflect2"
67
)
78

89
func decoderOfOptional(ctx *ctx, typ reflect2.Type) ValDecoder {
@@ -16,7 +17,7 @@ func encoderOfOptional(ctx *ctx, typ reflect2.Type) ValEncoder {
1617
ptrType := typ.(*reflect2.UnsafePtrType)
1718
elemType := ptrType.Elem()
1819
elemEncoder := encoderOfType(ctx, elemType)
19-
encoder := &OptionalEncoder{elemEncoder}
20+
encoder := &OptionalEncoder{ValueEncoder: elemEncoder, seen: make(map[unsafe.Pointer]bool, 1)}
2021
return encoder
2122
}
2223

@@ -61,13 +62,22 @@ func (decoder *dereferenceDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
6162

6263
type OptionalEncoder struct {
6364
ValueEncoder ValEncoder
65+
seen map[unsafe.Pointer]bool
6466
}
6567

6668
func (encoder *OptionalEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
67-
if *((*unsafe.Pointer)(ptr)) == nil {
69+
ptr = *((*unsafe.Pointer)(ptr))
70+
if encoder.seen[ptr] {
71+
stream.Error = ErrEncounterCycle
72+
return
73+
}
74+
encoder.seen[ptr] = true
75+
defer delete(encoder.seen, ptr)
76+
77+
if ptr == nil {
6878
stream.WriteNil()
6979
} else {
70-
encoder.ValueEncoder.Encode(*((*unsafe.Pointer)(ptr)), stream)
80+
encoder.ValueEncoder.Encode(ptr, stream)
7181
}
7282
}
7383

reflect_struct_encoder.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ package jsoniter
22

33
import (
44
"fmt"
5-
"github.com/modern-go/reflect2"
65
"io"
76
"reflect"
87
"unsafe"
8+
9+
"github.com/modern-go/reflect2"
910
)
1011

1112
func encoderOfStruct(ctx *ctx, typ reflect2.Type) ValEncoder {
@@ -54,7 +55,7 @@ func createCheckIsEmpty(ctx *ctx, typ reflect2.Type) checkIsEmpty {
5455
kind := typ.Kind()
5556
switch kind {
5657
case reflect.Interface:
57-
return &dynamicEncoder{typ}
58+
return &dynamicEncoder{valType: typ, seen: make(map[unsafe.Pointer]bool, 1)}
5859
case reflect.Struct:
5960
return &structEncoder{typ: typ}
6061
case reflect.Array:

stream.go

+3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package jsoniter
22

33
import (
4+
"errors"
45
"io"
56
)
67

8+
var ErrEncounterCycle = errors.New("encountered a cycle")
9+
710
// stream is a io.Writer like object, with JSON specific write functions.
811
// Error is not returned as return value, but stored as Error member on this stream instance.
912
type Stream struct {

value_tests/map_test.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,13 @@ func init() {
5757
"2018-12-14": true
5858
}`,
5959
}, unmarshalCase{
60-
ptr: (*map[customKey]string)(nil),
60+
ptr: (*map[customKey]string)(nil),
6161
input: `{"foo": "bar"}`,
6262
})
63+
64+
selfRecursive := map[string]interface{}{}
65+
selfRecursive["me"] = selfRecursive
66+
marshalSelfRecursiveCases = append(marshalSelfRecursiveCases, selfRecursive)
6367
}
6468

6569
type MyInterface interface {

value_tests/slice_test.go

+4
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,8 @@ func init() {
2424
ptr: (*[]byte)(nil),
2525
input: `"c3ViamVjdHM\/X2Q9MQ=="`,
2626
})
27+
28+
selfRecursive := []interface{}{nil}
29+
selfRecursive[0] = selfRecursive
30+
marshalSelfRecursiveCases = append(marshalSelfRecursiveCases, selfRecursive)
2731
}

value_tests/struct_test.go

+4
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,10 @@ func init() {
205205
"should not marshal",
206206
},
207207
)
208+
209+
selfRecursive := &structRecursive{}
210+
selfRecursive.Me = selfRecursive
211+
marshalSelfRecursiveCases = append(marshalSelfRecursiveCases, selfRecursive)
208212
}
209213

210214
type StructVarious struct {

value_tests/value_test.go

+17-2
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@ package test
33
import (
44
"encoding/json"
55
"fmt"
6-
"github.com/json-iterator/go"
6+
"testing"
7+
8+
jsoniter "github.com/json-iterator/go"
79
"github.com/modern-go/reflect2"
810
"github.com/stretchr/testify/require"
9-
"testing"
1011
)
1112

1213
type unmarshalCase struct {
@@ -22,6 +23,8 @@ var marshalCases = []interface{}{
2223
nil,
2324
}
2425

26+
var marshalSelfRecursiveCases = []interface{}{}
27+
2528
type selectedMarshalCase struct {
2629
marshalCase interface{}
2730
}
@@ -78,3 +81,15 @@ func Test_marshal(t *testing.T) {
7881
})
7982
}
8083
}
84+
85+
func Test_marshal_self_recursive(t *testing.T) {
86+
for i, testCase := range marshalSelfRecursiveCases {
87+
t.Run(fmt.Sprintf("[%v]%s", i, reflect2.TypeOf(testCase).String()), func(t *testing.T) {
88+
should := require.New(t)
89+
_, err1 := json.Marshal(testCase)
90+
should.ErrorContains(err1, "encountered a cycle")
91+
_, err2 := jsoniter.ConfigCompatibleWithStandardLibrary.Marshal(testCase)
92+
should.ErrorContains(err2, "encountered a cycle")
93+
})
94+
}
95+
}

0 commit comments

Comments
 (0)