Skip to content

Commit e52787c

Browse files
committed
GODRIVER-3302 Handle malformatted message length properly. (#1758)
(cherry picked from commit be25b9a)
1 parent 884fb42 commit e52787c

File tree

4 files changed

+464
-115
lines changed

4 files changed

+464
-115
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 105 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package topology
99
import (
1010
"context"
1111
"crypto/tls"
12+
"encoding/binary"
1213
"errors"
1314
"fmt"
1415
"io"
@@ -18,6 +19,7 @@ import (
1819
"sync/atomic"
1920
"time"
2021

22+
"go.mongodb.org/mongo-driver/v2/internal/csot"
2123
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
2224
"go.mongodb.org/mongo-driver/v2/mongo/address"
2325
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
@@ -80,9 +82,9 @@ type connection struct {
8082
// accessTokens in the OIDC authenticator cache.
8183
oidcTokenGenID uint64
8284

83-
// awaitingResponse indicates that the server response was not completely
85+
// awaitRemainingBytes indicates the size of server response that was not completely
8486
// read before returning the connection to the pool.
85-
awaitingResponse bool
87+
awaitRemainingBytes *int32
8688
}
8789

8890
// newConnection handles the creation of a connection. It does not connect the connection.
@@ -111,11 +113,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
111113
return c
112114
}
113115

114-
// DriverConnectionID returns the driver connection ID.
115-
func (c *connection) DriverConnectionID() int64 {
116-
return c.driverConnectionID
117-
}
118-
119116
// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
120117
// configuration.
121118
func (c *connection) setGenerationNumber() {
@@ -137,6 +134,39 @@ func (c *connection) hasGenerationNumber() bool {
137134
return driverutil.IsServerLoadBalanced(c.desc)
138135
}
139136

137+
func configureTLS(ctx context.Context,
138+
tlsConnSource tlsConnectionSource,
139+
nc net.Conn,
140+
addr address.Address,
141+
config *tls.Config,
142+
ocspOpts *ocsp.VerifyOptions,
143+
) (net.Conn, error) {
144+
// Ensure config.ServerName is always set for SNI.
145+
if config.ServerName == "" {
146+
hostname := addr.String()
147+
colonPos := strings.LastIndex(hostname, ":")
148+
if colonPos == -1 {
149+
colonPos = len(hostname)
150+
}
151+
152+
hostname = hostname[:colonPos]
153+
config.ServerName = hostname
154+
}
155+
156+
client := tlsConnSource.Client(nc, config)
157+
if err := clientHandshake(ctx, client); err != nil {
158+
return nil, err
159+
}
160+
161+
// Only do OCSP verification if TLS verification is requested.
162+
if !config.InsecureSkipVerify {
163+
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
164+
return nil, ocspErr
165+
}
166+
}
167+
return client, nil
168+
}
169+
140170
// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
141171
// handshakes. All errors returned by connect are considered "before the handshake completes" and
142172
// must be handled by calling the appropriate SDAM handshake error handler.
@@ -291,6 +321,10 @@ func (c *connection) closeConnectContext() {
291321
}
292322
}
293323

324+
func (c *connection) cancellationListenerCallback() {
325+
_ = c.close()
326+
}
327+
294328
func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error {
295329
if originalError == nil {
296330
return nil
@@ -313,10 +347,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
313347
return originalError
314348
}
315349

316-
func (c *connection) cancellationListenerCallback() {
317-
_ = c.close()
318-
}
319-
320350
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
321351
var err error
322352
if atomic.LoadInt64(&c.state) != connConnected {
@@ -377,14 +407,10 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
377407

378408
dst, errMsg, err := c.read(ctx)
379409
if err != nil {
380-
if nerr := net.Error(nil); errors.As(err, &nerr) && nerr.Timeout() {
381-
// If the error was a timeout error, instead of closing the
382-
// connection mark it as awaiting response so the pool can read the
383-
// response before making it available to other operations.
384-
c.awaitingResponse = true
385-
} else {
386-
// Otherwise, and close the connection because we don't know what
387-
// the connection state is.
410+
if c.awaitRemainingBytes == nil {
411+
// If the connection was not marked as awaiting response, use the
412+
// pre-CSOT behavior and close the connection because we don't know
413+
// if there are other bytes left to read.
388414
c.close()
389415
}
390416
message := errMsg
@@ -401,6 +427,26 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
401427
return dst, nil
402428
}
403429

430+
func (c *connection) parseWmSizeBytes(wmSizeBytes [4]byte) (int32, error) {
431+
// read the length as an int32
432+
size := int32(binary.LittleEndian.Uint32(wmSizeBytes[:]))
433+
434+
if size < 4 {
435+
return 0, fmt.Errorf("malformed message length: %d", size)
436+
}
437+
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
438+
// defaultMaxMessageSize instead.
439+
maxMessageSize := c.desc.MaxMessageSize
440+
if maxMessageSize == 0 {
441+
maxMessageSize = defaultMaxMessageSize
442+
}
443+
if uint32(size) > maxMessageSize {
444+
return 0, errResponseTooLarge
445+
}
446+
447+
return size, nil
448+
}
449+
404450
func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, err error) {
405451
go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
406452
defer func() {
@@ -414,36 +460,43 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
414460
}
415461
}()
416462

463+
isCSOTTimeout := func(err error) bool {
464+
// If the error was a timeout error and CSOT is enabled, instead of
465+
// closing the connection mark it as awaiting response so the pool
466+
// can read the response before making it available to other
467+
// operations.
468+
nerr := net.Error(nil)
469+
return errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx)
470+
}
471+
417472
// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
418473
// reslice dst once instead of twice.
419474
var sizeBuf [4]byte
420475

421476
// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
422477
// because there might be more than one wire message waiting to be read, for example when
423478
// reading messages from an exhaust cursor.
424-
_, err = io.ReadFull(c.nc, sizeBuf[:])
479+
n, err := io.ReadFull(c.nc, sizeBuf[:])
425480
if err != nil {
481+
if l := int32(n); l == 0 && isCSOTTimeout(err) {
482+
c.awaitRemainingBytes = &l
483+
}
426484
return nil, "incomplete read of message header", err
427485
}
428-
429-
// read the length as an int32
430-
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)
431-
432-
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
433-
// defaultMaxMessageSize instead.
434-
maxMessageSize := c.desc.MaxMessageSize
435-
if maxMessageSize == 0 {
436-
maxMessageSize = defaultMaxMessageSize
437-
}
438-
if uint32(size) > maxMessageSize {
439-
return nil, errResponseTooLarge.Error(), errResponseTooLarge
486+
size, err := c.parseWmSizeBytes(sizeBuf)
487+
if err != nil {
488+
return nil, err.Error(), err
440489
}
441490

442491
dst := make([]byte, size)
443492
copy(dst, sizeBuf[:])
444493

445-
_, err = io.ReadFull(c.nc, dst[4:])
494+
n, err = io.ReadFull(c.nc, dst[4:])
446495
if err != nil {
496+
remainingBytes := size - 4 - int32(n)
497+
if remainingBytes > 0 && isCSOTTimeout(err) {
498+
c.awaitRemainingBytes = &remainingBytes
499+
}
447500
return dst, "incomplete read of full message", err
448501
}
449502

@@ -496,10 +549,6 @@ func (c *connection) setCanStream(canStream bool) {
496549
c.canStream = canStream
497550
}
498551

499-
func (c initConnection) supportsStreaming() bool {
500-
return c.canStream
501-
}
502-
503552
func (c *connection) setStreaming(streaming bool) {
504553
c.currentlyStreaming = streaming
505554
}
@@ -508,6 +557,14 @@ func (c *connection) getCurrentlyStreaming() bool {
508557
return c.currentlyStreaming
509558
}
510559

560+
func (c *connection) previousCanceled() bool {
561+
if val := c.prevCanceled.Load(); val != nil {
562+
return val.(bool)
563+
}
564+
565+
return false
566+
}
567+
511568
func (c *connection) ID() string {
512569
return c.id
513570
}
@@ -516,12 +573,17 @@ func (c *connection) ServerConnectionID() *int64 {
516573
return c.serverConnectionID
517574
}
518575

519-
func (c *connection) previousCanceled() bool {
520-
if val := c.prevCanceled.Load(); val != nil {
521-
return val.(bool)
522-
}
576+
// DriverConnectionID returns the driver connection ID.
577+
func (c *connection) DriverConnectionID() int64 {
578+
return c.driverConnectionID
579+
}
523580

524-
return false
581+
func (c *connection) OIDCTokenGenID() uint64 {
582+
return c.oidcTokenGenID
583+
}
584+
585+
func (c *connection) SetOIDCTokenGenID(genID uint64) {
586+
c.oidcTokenGenID = genID
525587
}
526588

527589
// initConnection is an adapter used during connection initialization. It has the minimum
@@ -562,7 +624,7 @@ func (c initConnection) CurrentlyStreaming() bool {
562624
return c.getCurrentlyStreaming()
563625
}
564626
func (c initConnection) SupportsStreaming() bool {
565-
return c.supportsStreaming()
627+
return c.canStream
566628
}
567629

568630
// Connection implements the driver.Connection interface to allow reading and writing wire
@@ -797,39 +859,6 @@ func (c *Connection) DriverConnectionID() int64 {
797859
return c.connection.DriverConnectionID()
798860
}
799861

800-
func configureTLS(ctx context.Context,
801-
tlsConnSource tlsConnectionSource,
802-
nc net.Conn,
803-
addr address.Address,
804-
config *tls.Config,
805-
ocspOpts *ocsp.VerifyOptions,
806-
) (net.Conn, error) {
807-
// Ensure config.ServerName is always set for SNI.
808-
if config.ServerName == "" {
809-
hostname := addr.String()
810-
colonPos := strings.LastIndex(hostname, ":")
811-
if colonPos == -1 {
812-
colonPos = len(hostname)
813-
}
814-
815-
hostname = hostname[:colonPos]
816-
config.ServerName = hostname
817-
}
818-
819-
client := tlsConnSource.Client(nc, config)
820-
if err := clientHandshake(ctx, client); err != nil {
821-
return nil, err
822-
}
823-
824-
// Only do OCSP verification if TLS verification is requested.
825-
if !config.InsecureSkipVerify {
826-
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
827-
return nil, ocspErr
828-
}
829-
}
830-
return client, nil
831-
}
832-
833862
// OIDCTokenGenID returns the OIDC token generation ID.
834863
func (c *Connection) OIDCTokenGenID() uint64 {
835864
return c.oidcTokenGenID
@@ -839,11 +868,3 @@ func (c *Connection) OIDCTokenGenID() uint64 {
839868
func (c *Connection) SetOIDCTokenGenID(genID uint64) {
840869
c.oidcTokenGenID = genID
841870
}
842-
843-
func (c *connection) OIDCTokenGenID() uint64 {
844-
return c.oidcTokenGenID
845-
}
846-
847-
func (c *connection) SetOIDCTokenGenID(genID uint64) {
848-
c.oidcTokenGenID = genID
849-
}

x/mongo/driver/topology/connection_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,23 @@ func TestConnection(t *testing.T) {
393393
}
394394
listener.assertCalledOnce(t)
395395
})
396+
t.Run("size too small errors", func(t *testing.T) {
397+
err := errors.New("malformed message length: 3")
398+
tnc := &testNetConn{readerr: err, buf: []byte{0x03, 0x00, 0x00, 0x00}}
399+
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
400+
listener := newTestCancellationListener(false)
401+
conn.cancellationListener = listener
402+
403+
want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: err.Error()}
404+
_, got := conn.readWireMessage(context.Background())
405+
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
406+
t.Errorf("errors do not match. got %v; want %v", got, want)
407+
}
408+
if !tnc.closed {
409+
t.Errorf("failed to closeConnection net.Conn after error writing bytes.")
410+
}
411+
listener.assertCalledOnce(t)
412+
})
396413
t.Run("full message read errors", func(t *testing.T) {
397414
err := errors.New("Read error")
398415
tnc := &testNetConn{readerr: err, buf: []byte{0x11, 0x00, 0x00, 0x00}}

0 commit comments

Comments
 (0)