diff --git a/bson/bsoncodec/default_value_decoders.go b/bson/bsoncodec/default_value_decoders.go index 159297ef0a..7ae7638caf 100644 --- a/bson/bsoncodec/default_value_decoders.go +++ b/bson/bsoncodec/default_value_decoders.go @@ -1521,7 +1521,13 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} } - if vr.Type() == bsontype.Null { + // If the BSON value is null and the Go value is a pointer, then don't call + // UnmarshalBSONValue. Even if the Go pointer is already initialized (i.e., + // non-nil), encountering null in BSON will result in the pointer being + // directly set to nil here. Since the pointer is being replaced with nil, + // there is no opportunity (or reason) for the custom UnmarshalBSONValue logic + // to be called. + if vr.Type() == bsontype.Null && val.Kind() == reflect.Ptr { val.Set(reflect.Zero(val.Type())) return vr.ReadNull() @@ -1563,6 +1569,18 @@ func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr bsonr return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} } + // If the BSON value is null and the Go value is a pointer, then don't call + // UnmarshalBSON. Even if the Go pointer is already initialized (i.e., + // non-nil), encountering null in BSON will result in the pointer being + // directly set to nil here. Since the pointer is being replaced with nil, + // there is no opportunity (or reason) for the custom UnmarshalBSON logic to + // be called. 
+ if val.Kind() == reflect.Ptr && vr.Type() == bsontype.Null { + val.Set(reflect.Zero(val.Type())) + + return vr.ReadNull() + } + if val.Kind() == reflect.Ptr && val.IsNil() { if !val.CanSet() { return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} @@ -1575,18 +1593,6 @@ func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr bsonr return err } - // If the target Go value is a pointer and the BSON field value is empty, set the value to the - // zero value of the pointer (nil) and don't call UnmarshalBSON. UnmarshalBSON has no way to - // change the pointer value from within the function (only the value at the pointer address), - // so it can't set the pointer to "nil" itself. Since the most common Go value for an empty BSON - // field value is "nil", we set "nil" here and don't call UnmarshalBSON. This behavior matches - // the behavior of the Go "encoding/json" unmarshaler when the target Go value is a pointer and - // the JSON field value is "null". - if val.Kind() == reflect.Ptr && len(src) == 0 { - val.Set(reflect.Zero(val.Type())) - return nil - } - if !val.Type().Implements(tUnmarshaler) { if !val.CanAddr() { return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} diff --git a/bson/unmarshal.go b/bson/unmarshal.go index 66da17ee01..d749ba373b 100644 --- a/bson/unmarshal.go +++ b/bson/unmarshal.go @@ -41,6 +41,9 @@ type ValueUnmarshaler interface { // Unmarshal parses the BSON-encoded data and stores the result in the value // pointed to by val. If val is nil or not a pointer, Unmarshal returns // InvalidUnmarshalError. +// +// If the BSON value is null and the Go value is a pointer, the pointer is set +// to nil without calling UnmarshalBSONValue or UnmarshalBSON. 
func Unmarshal(data []byte, val interface{}) error { return UnmarshalWithRegistry(DefaultRegistry, data, val) } diff --git a/bson/unmarshal_value_test.go b/bson/unmarshal_value_test.go index ef91da1659..3455deeaaa 100644 --- a/bson/unmarshal_value_test.go +++ b/bson/unmarshal_value_test.go @@ -14,6 +14,7 @@ import ( "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) @@ -93,6 +94,29 @@ func TestUnmarshalValue(t *testing.T) { }) } +func TestInitializedPointerDataWithBSONNull(t *testing.T) { + // Set up the test case with initialized pointers. + tc := unmarshalBehaviorTestCase{ + BSONValuePtrTracker: &unmarshalBSONValueCallTracker{}, + BSONPtrTracker: &unmarshalBSONCallTracker{}, + } + + // Create BSON data where the '*_ptr_tracker' fields are explicitly set to + // null. + bytes := docToBytes(D{ + {Key: "bv_ptr_tracker", Value: nil}, + {Key: "b_ptr_tracker", Value: nil}, + }) + + // Unmarshal the BSON data into the test case struct. This should set the + // pointer fields to nil due to the BSON null value. 
+ err := Unmarshal(bytes, &tc) + require.NoError(t, err) + + assert.Nil(t, tc.BSONValuePtrTracker) + assert.Nil(t, tc.BSONPtrTracker) +} + // tests covering GODRIVER-2779 func BenchmarkSliceCodecUnmarshal(b *testing.B) { benchmarks := []struct { diff --git a/bson/unmarshaling_cases_test.go b/bson/unmarshaling_cases_test.go index dd38369bff..e9088f219c 100644 --- a/bson/unmarshaling_cases_test.go +++ b/bson/unmarshaling_cases_test.go @@ -114,6 +114,26 @@ func unmarshalingTestCases() []unmarshalingTestCase { }, data: docToBytes(D{{"fooBar", int32(10)}}), }, + { + name: "nil pointer and non-pointer type with literal null BSON", + sType: reflect.TypeOf(unmarshalBehaviorTestCase{}), + want: &unmarshalBehaviorTestCase{ + BSONValueTracker: unmarshalBSONValueCallTracker{ + called: true, + }, + BSONValuePtrTracker: nil, + BSONTracker: unmarshalBSONCallTracker{ + called: true, + }, + BSONPtrTracker: nil, + }, + data: docToBytes(D{ + {Key: "bv_tracker", Value: nil}, + {Key: "bv_ptr_tracker", Value: nil}, + {Key: "b_tracker", Value: nil}, + {Key: "b_ptr_tracker", Value: nil}, + }), + }, // GODRIVER-2252 // Test that a struct of pointer types with UnmarshalBSON functions defined marshal and // unmarshal to the same Go values when the pointer values are "nil". @@ -269,3 +289,39 @@ func (ms *myString) UnmarshalBSON(bytes []byte) error { *ms = myString(s) return nil } + +// unmarshalBSONValueCallTracker is a test struct that tracks whether the +// UnmarshalBSONValue method has been called. +type unmarshalBSONValueCallTracker struct { + called bool // called is set to true when UnmarshalBSONValue is invoked. +} + +var _ ValueUnmarshaler = &unmarshalBSONValueCallTracker{} + +// unmarshalBSONCallTracker is a test struct that tracks whether the +// UnmarshalBSON method has been called. +type unmarshalBSONCallTracker struct { + called bool // called is set to true when UnmarshalBSON is invoked. +} + +// Ensure unmarshalBSONCallTracker implements the Unmarshaler interface. 
+var _ Unmarshaler = &unmarshalBSONCallTracker{} + +// unmarshalBehaviorTestCase holds instances of call trackers for testing BSON +// unmarshaling behavior. +type unmarshalBehaviorTestCase struct { + BSONValueTracker unmarshalBSONValueCallTracker `bson:"bv_tracker"` // BSON value unmarshaling by value. + BSONValuePtrTracker *unmarshalBSONValueCallTracker `bson:"bv_ptr_tracker"` // BSON value unmarshaling by pointer. + BSONTracker unmarshalBSONCallTracker `bson:"b_tracker"` // BSON unmarshaling by value. + BSONPtrTracker *unmarshalBSONCallTracker `bson:"b_ptr_tracker"` // BSON unmarshaling by pointer. +} + +func (tracker *unmarshalBSONValueCallTracker) UnmarshalBSONValue(bsontype.Type, []byte) error { + tracker.called = true + return nil +} + +func (tracker *unmarshalBSONCallTracker) UnmarshalBSON([]byte) error { + tracker.called = true + return nil +}