Skip to content

Commit 659342c

Browse files
GODRIVER-3470 Correct BSON unmarshaling logic for null values (mongodb#1924)
1 parent e796c82 commit 659342c

File tree

4 files changed

+135
-1
lines changed

4 files changed

+135
-1
lines changed

bson/bsoncodec/default_value_decoders.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1521,7 +1521,13 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr
15211521
return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val}
15221522
}
15231523

1524-
if vr.Type() == bsontype.Null {
1524+
// If BSON value is null and the go value is a pointer, then don't call
1525+
// UnmarshalBSONValue. Even if the Go pointer is already initialized (i.e.,
1526+
// non-nil), encountering null in BSON will result in the pointer being
1527+
// directly set to nil here. Since the pointer is being replaced with nil,
1528+
// there is no opportunity (or reason) for the custom UnmarshalBSONValue logic
1529+
// to be called.
1530+
if vr.Type() == bsontype.Null && val.Kind() == reflect.Ptr {
15251531
val.Set(reflect.Zero(val.Type()))
15261532

15271533
return vr.ReadNull()

bson/unmarshal.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ type ValueUnmarshaler interface {
4141
// Unmarshal parses the BSON-encoded data and stores the result in the value
4242
// pointed to by val. If val is nil or not a pointer, Unmarshal returns
4343
// InvalidUnmarshalError.
44+
//
45+
// When unmarshaling BSON, if the BSON value is null and the Go value is a
46+
// pointer, the pointer is set to nil without calling UnmarshalBSONValue.
4447
func Unmarshal(data []byte, val interface{}) error {
4548
return UnmarshalWithRegistry(DefaultRegistry, data, val)
4649
}

bson/unmarshal_value_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"go.mongodb.org/mongo-driver/bson/bsoncodec"
1515
"go.mongodb.org/mongo-driver/bson/bsontype"
1616
"go.mongodb.org/mongo-driver/internal/assert"
17+
"go.mongodb.org/mongo-driver/internal/require"
1718
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
1819
)
1920

@@ -93,6 +94,29 @@ func TestUnmarshalValue(t *testing.T) {
9394
})
9495
}
9596

97+
func TestInitializedPointerDataWithBSONNull(t *testing.T) {
98+
// Set up the test case with initialized pointers.
99+
tc := unmarshalBehaviorTestCase{
100+
BSONValuePtrTracker: &unmarshalBSONValueCallTracker{},
101+
BSONPtrTracker: &unmarshalBSONCallTracker{},
102+
}
103+
104+
// Create BSON data where the '*_ptr_tracker' fields are explicitly set to
105+
// null.
106+
bytes := docToBytes(D{
107+
{Key: "bv_ptr_tracker", Value: nil},
108+
{Key: "b_ptr_tracker", Value: nil},
109+
})
110+
111+
// Unmarshal the BSON data into the test case struct. This should set the
112+
// pointer fields to nil due to the BSON null value.
113+
err := Unmarshal(bytes, &tc)
114+
require.NoError(t, err)
115+
116+
assert.Nil(t, tc.BSONValuePtrTracker)
117+
assert.Nil(t, tc.BSONPtrTracker)
118+
}
119+
96120
// tests covering GODRIVER-2779
97121
func BenchmarkSliceCodecUnmarshal(b *testing.B) {
98122
benchmarks := []struct {

bson/unmarshaling_cases_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"go.mongodb.org/mongo-driver/bson/bsonrw"
1313
"go.mongodb.org/mongo-driver/bson/bsontype"
14+
"go.mongodb.org/mongo-driver/bson/primitive"
1415
)
1516

1617
type unmarshalingTestCase struct {
@@ -114,6 +115,26 @@ func unmarshalingTestCases() []unmarshalingTestCase {
114115
},
115116
data: docToBytes(D{{"fooBar", int32(10)}}),
116117
},
118+
{
119+
name: "nil pointer and non-pointer type with literal null BSON",
120+
sType: reflect.TypeOf(unmarshalBehaviorTestCase{}),
121+
want: &unmarshalBehaviorTestCase{
122+
BSONValueTracker: unmarshalBSONValueCallTracker{
123+
called: true,
124+
},
125+
BSONValuePtrTracker: nil,
126+
BSONTracker: unmarshalBSONCallTracker{
127+
called: true,
128+
},
129+
BSONPtrTracker: nil,
130+
},
131+
data: docToBytes(D{
132+
{Key: "bv_tracker", Value: nil},
133+
{Key: "bv_ptr_tracker", Value: nil},
134+
{Key: "b_tracker", Value: nil},
135+
{Key: "b_ptr_tracker", Value: nil},
136+
}),
137+
},
117138
// GODRIVER-2252
118139
// Test that a struct of pointer types with UnmarshalBSON functions defined marshal and
119140
// unmarshal to the same Go values when the pointer values are "nil".
@@ -174,6 +195,50 @@ func unmarshalingTestCases() []unmarshalingTestCase {
174195
want: &valNonPtrStruct,
175196
data: docToBytes(valNonPtrStruct),
176197
},
198+
{
199+
name: "nil pointer and non-pointer type with BSON minkey",
200+
sType: reflect.TypeOf(unmarshalBehaviorTestCase{}),
201+
want: &unmarshalBehaviorTestCase{
202+
BSONValueTracker: unmarshalBSONValueCallTracker{
203+
called: true,
204+
},
205+
BSONValuePtrTracker: &unmarshalBSONValueCallTracker{
206+
called: true,
207+
},
208+
BSONTracker: unmarshalBSONCallTracker{
209+
called: true,
210+
},
211+
BSONPtrTracker: nil,
212+
},
213+
data: docToBytes(D{
214+
{Key: "bv_tracker", Value: primitive.MinKey{}},
215+
{Key: "bv_ptr_tracker", Value: primitive.MinKey{}},
216+
{Key: "b_tracker", Value: primitive.MinKey{}},
217+
{Key: "b_ptr_tracker", Value: primitive.MinKey{}},
218+
}),
219+
},
220+
{
221+
name: "nil pointer and non-pointer type with BSON maxkey",
222+
sType: reflect.TypeOf(unmarshalBehaviorTestCase{}),
223+
want: &unmarshalBehaviorTestCase{
224+
BSONValueTracker: unmarshalBSONValueCallTracker{
225+
called: true,
226+
},
227+
BSONValuePtrTracker: &unmarshalBSONValueCallTracker{
228+
called: true,
229+
},
230+
BSONTracker: unmarshalBSONCallTracker{
231+
called: true,
232+
},
233+
BSONPtrTracker: nil,
234+
},
235+
data: docToBytes(D{
236+
{Key: "bv_tracker", Value: primitive.MaxKey{}},
237+
{Key: "bv_ptr_tracker", Value: primitive.MaxKey{}},
238+
{Key: "b_tracker", Value: primitive.MaxKey{}},
239+
{Key: "b_ptr_tracker", Value: primitive.MaxKey{}},
240+
}),
241+
},
177242
}
178243
}
179244

@@ -269,3 +334,39 @@ func (ms *myString) UnmarshalBSON(bytes []byte) error {
269334
*ms = myString(s)
270335
return nil
271336
}
337+
338+
// unmarshalBSONValueCallTracker is a test struct that tracks whether the
339+
// UnmarshalBSONValue method has been called.
340+
type unmarshalBSONValueCallTracker struct {
341+
called bool // called is set to true when UnmarshalBSONValue is invoked.
342+
}
343+
344+
var _ ValueUnmarshaler = &unmarshalBSONValueCallTracker{}
345+
346+
// unmarshalBSONCallTracker is a test struct that tracks whether the
347+
// UnmarshalBSON method has been called.
348+
type unmarshalBSONCallTracker struct {
349+
called bool // called is set to true when UnmarshalBSON is invoked.
350+
}
351+
352+
// Ensure unmarshalBSONCallTracker implements the Unmarshaler interface.
353+
var _ Unmarshaler = &unmarshalBSONCallTracker{}
354+
355+
// unmarshalBehaviorTestCase holds instances of call trackers for testing BSON
356+
// unmarshaling behavior.
357+
type unmarshalBehaviorTestCase struct {
358+
BSONValueTracker unmarshalBSONValueCallTracker `bson:"bv_tracker"` // BSON value unmarshaling by value.
359+
BSONValuePtrTracker *unmarshalBSONValueCallTracker `bson:"bv_ptr_tracker"` // BSON value unmarshaling by pointer.
360+
BSONTracker unmarshalBSONCallTracker `bson:"b_tracker"` // BSON unmarshaling by value.
361+
BSONPtrTracker *unmarshalBSONCallTracker `bson:"b_ptr_tracker"` // BSON unmarshaling by pointer.
362+
}
363+
364+
func (tracker *unmarshalBSONValueCallTracker) UnmarshalBSONValue(bsontype.Type, []byte) error {
365+
tracker.called = true
366+
return nil
367+
}
368+
369+
func (tracker *unmarshalBSONCallTracker) UnmarshalBSON([]byte) error {
370+
tracker.called = true
371+
return nil
372+
}

0 commit comments

Comments
 (0)