Skip to content

Commit a84f0f8

Browse files
matthewdaleprestonvasquez
authored andcommitted
GODRIVER-3470 Correct BSON unmarshaling logic for null values (#1924) [master] (#1945)
Co-authored-by: Preston Vasquez <[email protected]> (cherry picked from commit 25df82f)
1 parent 80c2f1d commit a84f0f8

File tree

4 files changed

+133
-2
lines changed

4 files changed

+133
-2
lines changed

bson/default_value_decoders.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1166,7 +1166,13 @@ func valueUnmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Va
11661166
return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val}
11671167
}
11681168

1169-
if vr.Type() == TypeNull {
1169+
// If BSON value is null and the go value is a pointer, then don't call
1170+
// UnmarshalBSONValue. Even if the Go pointer is already initialized (i.e.,
1171+
// non-nil), encountering null in BSON will result in the pointer being
1172+
// directly set to nil here. Since the pointer is being replaced with nil,
1173+
// there is no opportunity (or reason) for the custom UnmarshalBSONValue logic
1174+
// to be called.
1175+
if vr.Type() == TypeNull && val.Kind() == reflect.Ptr {
11701176
val.Set(reflect.Zero(val.Type()))
11711177

11721178
return vr.ReadNull()

bson/unmarshal.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ type ValueUnmarshaler interface {
3636
}
3737

3838
// Unmarshal parses the BSON-encoded data and stores the result in the value
39-
// pointed to by val. If val is nil or not a pointer, Unmarshal returns an error.
39+
// pointed to by val. If val is nil or not a pointer, Unmarshal returns an
40+
// error.
41+
//
42+
// When unmarshaling BSON, if the BSON value is null and the Go value is a
43+
// pointer, the pointer is set to nil without calling UnmarshalBSONValue.
4044
func Unmarshal(data []byte, val interface{}) error {
4145
vr := newDocumentReader(bytes.NewReader(data))
4246
if l, err := vr.peekLength(); err != nil {

bson/unmarshal_value_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"testing"
1414

1515
"go.mongodb.org/mongo-driver/v2/internal/assert"
16+
"go.mongodb.org/mongo-driver/v2/internal/require"
1617
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
1718
)
1819

@@ -39,6 +40,26 @@ func TestUnmarshalValue(t *testing.T) {
3940
})
4041
}
4142

43+
func TestInitializedPointerDataWithBSONNull(t *testing.T) {
44+
// Set up the test case with initialized pointers.
45+
tc := unmarshalBehaviorTestCase{
46+
BSONValuePtrTracker: &unmarshalBSONValueCallTracker{},
47+
BSONPtrTracker: &unmarshalBSONCallTracker{},
48+
}
49+
// Create BSON data where the '*_ptr_tracker' fields are explicitly set to
50+
// null.
51+
bytes := docToBytes(D{
52+
{Key: "bv_ptr_tracker", Value: nil},
53+
{Key: "b_ptr_tracker", Value: nil},
54+
})
55+
// Unmarshal the BSON data into the test case struct. This should set the
56+
// pointer fields to nil due to the BSON null value.
57+
err := Unmarshal(bytes, &tc)
58+
require.NoError(t, err)
59+
assert.Nil(t, tc.BSONValuePtrTracker)
60+
assert.Nil(t, tc.BSONPtrTracker)
61+
}
62+
4263
// tests covering GODRIVER-2779
4364
func BenchmarkSliceCodecUnmarshal(b *testing.B) {
4465
benchmarks := []struct {

bson/unmarshaling_cases_test.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,70 @@ func unmarshalingTestCases() []unmarshalingTestCase {
172172
want: &valNonPtrStruct,
173173
data: docToBytes(valNonPtrStruct),
174174
},
175+
{
176+
name: "nil pointer and non-pointer type with literal null BSON",
177+
sType: reflect.TypeOf(unmarshalBehaviorTestCase{}),
178+
want: &unmarshalBehaviorTestCase{
179+
BSONValueTracker: unmarshalBSONValueCallTracker{
180+
called: true,
181+
},
182+
BSONValuePtrTracker: nil,
183+
BSONTracker: unmarshalBSONCallTracker{
184+
called: true,
185+
},
186+
BSONPtrTracker: nil,
187+
},
188+
data: docToBytes(D{
189+
{Key: "bv_tracker", Value: nil},
190+
{Key: "bv_ptr_tracker", Value: nil},
191+
{Key: "b_tracker", Value: nil},
192+
{Key: "b_ptr_tracker", Value: nil},
193+
}),
194+
},
195+
{
196+
name: "nil pointer and non-pointer type with BSON minkey",
197+
sType: reflect.TypeOf(unmarshalBehaviorTestCase{}),
198+
want: &unmarshalBehaviorTestCase{
199+
BSONValueTracker: unmarshalBSONValueCallTracker{
200+
called: true,
201+
},
202+
BSONValuePtrTracker: &unmarshalBSONValueCallTracker{
203+
called: true,
204+
},
205+
BSONTracker: unmarshalBSONCallTracker{
206+
called: true,
207+
},
208+
BSONPtrTracker: nil,
209+
},
210+
data: docToBytes(D{
211+
{Key: "bv_tracker", Value: MinKey{}},
212+
{Key: "bv_ptr_tracker", Value: MinKey{}},
213+
{Key: "b_tracker", Value: MinKey{}},
214+
{Key: "b_ptr_tracker", Value: MinKey{}},
215+
}),
216+
},
217+
{
218+
name: "nil pointer and non-pointer type with BSON maxkey",
219+
sType: reflect.TypeOf(unmarshalBehaviorTestCase{}),
220+
want: &unmarshalBehaviorTestCase{
221+
BSONValueTracker: unmarshalBSONValueCallTracker{
222+
called: true,
223+
},
224+
BSONValuePtrTracker: &unmarshalBSONValueCallTracker{
225+
called: true,
226+
},
227+
BSONTracker: unmarshalBSONCallTracker{
228+
called: true,
229+
},
230+
BSONPtrTracker: nil,
231+
},
232+
data: docToBytes(D{
233+
{Key: "bv_tracker", Value: MaxKey{}},
234+
{Key: "bv_ptr_tracker", Value: MaxKey{}},
235+
{Key: "b_tracker", Value: MaxKey{}},
236+
{Key: "b_ptr_tracker", Value: MaxKey{}},
237+
}),
238+
},
175239
}
176240
}
177241

@@ -267,3 +331,39 @@ func (ms *myString) UnmarshalBSON(b []byte) error {
267331
*ms = myString(s)
268332
return nil
269333
}
334+
335+
// unmarshalBSONValueCallTracker is a test struct that tracks whether the
336+
// UnmarshalBSONValue method has been called.
337+
type unmarshalBSONValueCallTracker struct {
338+
called bool // called is set to true when UnmarshalBSONValue is invoked.
339+
}
340+
341+
var _ ValueUnmarshaler = &unmarshalBSONValueCallTracker{}
342+
343+
// unmarshalBSONCallTracker is a test struct that tracks whether the
344+
// UnmarshalBSON method has been called.
345+
type unmarshalBSONCallTracker struct {
346+
called bool // called is set to true when UnmarshalBSON is invoked.
347+
}
348+
349+
// Ensure unmarshalBSONCallTracker implements the Unmarshaler interface.
350+
var _ Unmarshaler = &unmarshalBSONCallTracker{}
351+
352+
// unmarshalBehaviorTestCase holds instances of call trackers for testing BSON
353+
// unmarshaling behavior.
354+
type unmarshalBehaviorTestCase struct {
355+
BSONValueTracker unmarshalBSONValueCallTracker `bson:"bv_tracker"` // BSON value unmarshaling by value.
356+
BSONValuePtrTracker *unmarshalBSONValueCallTracker `bson:"bv_ptr_tracker"` // BSON value unmarshaling by pointer.
357+
BSONTracker unmarshalBSONCallTracker `bson:"b_tracker"` // BSON unmarshaling by value.
358+
BSONPtrTracker *unmarshalBSONCallTracker `bson:"b_ptr_tracker"` // BSON unmarshaling by pointer.
359+
}
360+
361+
func (tracker *unmarshalBSONValueCallTracker) UnmarshalBSONValue(byte, []byte) error {
362+
tracker.called = true
363+
return nil
364+
}
365+
366+
func (tracker *unmarshalBSONCallTracker) UnmarshalBSON([]byte) error {
367+
tracker.called = true
368+
return nil
369+
}

0 commit comments

Comments
 (0)