Skip to content

Commit be25b9a

Browse files
authored
GODRIVER-3302 Handle malformatted message length properly. (#1758)
1 parent 4757f44 commit be25b9a

File tree

4 files changed

+460
-112
lines changed

4 files changed

+460
-112
lines changed

x/mongo/driver/topology/connection.go

+100-81
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package topology
99
import (
1010
"context"
1111
"crypto/tls"
12+
"encoding/binary"
1213
"errors"
1314
"fmt"
1415
"io"
@@ -79,9 +80,9 @@ type connection struct {
7980
driverConnectionID uint64
8081
generation uint64
8182

82-
// awaitingResponse indicates that the server response was not completely
83+
// awaitRemainingBytes indicates the size of server response that was not completely
8384
// read before returning the connection to the pool.
84-
awaitingResponse bool
85+
awaitRemainingBytes *int32
8586

8687
// oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate
8788
// accessTokens in the OIDC authenticator cache.
@@ -115,12 +116,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
115116
return c
116117
}
117118

118-
// DriverConnectionID returns the driver connection ID.
119-
// TODO(GODRIVER-2824): change return type to int64.
120-
func (c *connection) DriverConnectionID() uint64 {
121-
return c.driverConnectionID
122-
}
123-
124119
// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
125120
// configuration.
126121
func (c *connection) setGenerationNumber() {
@@ -142,6 +137,39 @@ func (c *connection) hasGenerationNumber() bool {
142137
return c.desc.LoadBalanced()
143138
}
144139

140+
func configureTLS(ctx context.Context,
141+
tlsConnSource tlsConnectionSource,
142+
nc net.Conn,
143+
addr address.Address,
144+
config *tls.Config,
145+
ocspOpts *ocsp.VerifyOptions,
146+
) (net.Conn, error) {
147+
// Ensure config.ServerName is always set for SNI.
148+
if config.ServerName == "" {
149+
hostname := addr.String()
150+
colonPos := strings.LastIndex(hostname, ":")
151+
if colonPos == -1 {
152+
colonPos = len(hostname)
153+
}
154+
155+
hostname = hostname[:colonPos]
156+
config.ServerName = hostname
157+
}
158+
159+
client := tlsConnSource.Client(nc, config)
160+
if err := clientHandshake(ctx, client); err != nil {
161+
return nil, err
162+
}
163+
164+
// Only do OCSP verification if TLS verification is requested.
165+
if !config.InsecureSkipVerify {
166+
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
167+
return nil, ocspErr
168+
}
169+
}
170+
return client, nil
171+
}
172+
145173
// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
146174
// handshakes. All errors returned by connect are considered "before the handshake completes" and
147175
// must be handled by calling the appropriate SDAM handshake error handler.
@@ -317,6 +345,10 @@ func (c *connection) closeConnectContext() {
317345
}
318346
}
319347

348+
func (c *connection) cancellationListenerCallback() {
349+
_ = c.close()
350+
}
351+
320352
func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error {
321353
if originalError == nil {
322354
return nil
@@ -339,10 +371,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
339371
return originalError
340372
}
341373

342-
func (c *connection) cancellationListenerCallback() {
343-
_ = c.close()
344-
}
345-
346374
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
347375
var err error
348376
if atomic.LoadInt64(&c.state) != connConnected {
@@ -423,15 +451,10 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
423451

424452
dst, errMsg, err := c.read(ctx)
425453
if err != nil {
426-
if nerr := net.Error(nil); errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx) {
427-
// If the error was a timeout error and CSOT is enabled, instead of
428-
// closing the connection mark it as awaiting response so the pool
429-
// can read the response before making it available to other
430-
// operations.
431-
c.awaitingResponse = true
432-
} else {
433-
// Otherwise, use the pre-CSOT behavior and close the connection
434-
// because we don't know if there are other bytes left to read.
454+
if c.awaitRemainingBytes == nil {
455+
// If the connection was not marked as awaiting response, use the
456+
// pre-CSOT behavior and close the connection because we don't know
457+
// if there are other bytes left to read.
435458
c.close()
436459
}
437460
message := errMsg
@@ -448,6 +471,26 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
448471
return dst, nil
449472
}
450473

474+
func (c *connection) parseWmSizeBytes(wmSizeBytes [4]byte) (int32, error) {
475+
// read the length as an int32
476+
size := int32(binary.LittleEndian.Uint32(wmSizeBytes[:]))
477+
478+
if size < 4 {
479+
return 0, fmt.Errorf("malformed message length: %d", size)
480+
}
481+
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
482+
// defaultMaxMessageSize instead.
483+
maxMessageSize := c.desc.MaxMessageSize
484+
if maxMessageSize == 0 {
485+
maxMessageSize = defaultMaxMessageSize
486+
}
487+
if uint32(size) > maxMessageSize {
488+
return 0, errResponseTooLarge
489+
}
490+
491+
return size, nil
492+
}
493+
451494
func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, err error) {
452495
go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
453496
defer func() {
@@ -461,36 +504,43 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
461504
}
462505
}()
463506

507+
isCSOTTimeout := func(err error) bool {
508+
// If the error was a timeout error and CSOT is enabled, instead of
509+
// closing the connection mark it as awaiting response so the pool
510+
// can read the response before making it available to other
511+
// operations.
512+
nerr := net.Error(nil)
513+
return errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx)
514+
}
515+
464516
// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
465517
// reslice dst once instead of twice.
466518
var sizeBuf [4]byte
467519

468520
// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
469521
// because there might be more than one wire message waiting to be read, for example when
470522
// reading messages from an exhaust cursor.
471-
_, err = io.ReadFull(c.nc, sizeBuf[:])
523+
n, err := io.ReadFull(c.nc, sizeBuf[:])
472524
if err != nil {
525+
if l := int32(n); l == 0 && isCSOTTimeout(err) {
526+
c.awaitRemainingBytes = &l
527+
}
473528
return nil, "incomplete read of message header", err
474529
}
475-
476-
// read the length as an int32
477-
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)
478-
479-
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
480-
// defaultMaxMessageSize instead.
481-
maxMessageSize := c.desc.MaxMessageSize
482-
if maxMessageSize == 0 {
483-
maxMessageSize = defaultMaxMessageSize
484-
}
485-
if uint32(size) > maxMessageSize {
486-
return nil, errResponseTooLarge.Error(), errResponseTooLarge
530+
size, err := c.parseWmSizeBytes(sizeBuf)
531+
if err != nil {
532+
return nil, err.Error(), err
487533
}
488534

489535
dst := make([]byte, size)
490536
copy(dst, sizeBuf[:])
491537

492-
_, err = io.ReadFull(c.nc, dst[4:])
538+
n, err = io.ReadFull(c.nc, dst[4:])
493539
if err != nil {
540+
remainingBytes := size - 4 - int32(n)
541+
if remainingBytes > 0 && isCSOTTimeout(err) {
542+
c.awaitRemainingBytes = &remainingBytes
543+
}
494544
return dst, "incomplete read of full message", err
495545
}
496546

@@ -537,10 +587,6 @@ func (c *connection) setCanStream(canStream bool) {
537587
c.canStream = canStream
538588
}
539589

540-
func (c initConnection) supportsStreaming() bool {
541-
return c.canStream
542-
}
543-
544590
func (c *connection) setStreaming(streaming bool) {
545591
c.currentlyStreaming = streaming
546592
}
@@ -554,6 +600,12 @@ func (c *connection) setSocketTimeout(timeout time.Duration) {
554600
c.writeTimeout = timeout
555601
}
556602

603+
// DriverConnectionID returns the driver connection ID.
604+
// TODO(GODRIVER-2824): change return type to int64.
605+
func (c *connection) DriverConnectionID() uint64 {
606+
return c.driverConnectionID
607+
}
608+
557609
func (c *connection) ID() string {
558610
return c.id
559611
}
@@ -562,6 +614,14 @@ func (c *connection) ServerConnectionID() *int64 {
562614
return c.serverConnectionID
563615
}
564616

617+
func (c *connection) OIDCTokenGenID() uint64 {
618+
return c.oidcTokenGenID
619+
}
620+
621+
func (c *connection) SetOIDCTokenGenID(genID uint64) {
622+
c.oidcTokenGenID = genID
623+
}
624+
565625
// initConnection is an adapter used during connection initialization. It has the minimum
566626
// functionality necessary to implement the driver.Connection interface, which is required to pass a
567627
// *connection to a Handshaker.
@@ -599,7 +659,7 @@ func (c initConnection) CurrentlyStreaming() bool {
599659
return c.getCurrentlyStreaming()
600660
}
601661
func (c initConnection) SupportsStreaming() bool {
602-
return c.supportsStreaming()
662+
return c.canStream
603663
}
604664

605665
// Connection implements the driver.Connection interface to allow reading and writing wire
@@ -833,39 +893,6 @@ func (c *Connection) DriverConnectionID() uint64 {
833893
return c.connection.DriverConnectionID()
834894
}
835895

836-
func configureTLS(ctx context.Context,
837-
tlsConnSource tlsConnectionSource,
838-
nc net.Conn,
839-
addr address.Address,
840-
config *tls.Config,
841-
ocspOpts *ocsp.VerifyOptions,
842-
) (net.Conn, error) {
843-
// Ensure config.ServerName is always set for SNI.
844-
if config.ServerName == "" {
845-
hostname := addr.String()
846-
colonPos := strings.LastIndex(hostname, ":")
847-
if colonPos == -1 {
848-
colonPos = len(hostname)
849-
}
850-
851-
hostname = hostname[:colonPos]
852-
config.ServerName = hostname
853-
}
854-
855-
client := tlsConnSource.Client(nc, config)
856-
if err := clientHandshake(ctx, client); err != nil {
857-
return nil, err
858-
}
859-
860-
// Only do OCSP verification if TLS verification is requested.
861-
if !config.InsecureSkipVerify {
862-
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
863-
return nil, ocspErr
864-
}
865-
}
866-
return client, nil
867-
}
868-
869896
// OIDCTokenGenID returns the OIDC token generation ID.
870897
func (c *Connection) OIDCTokenGenID() uint64 {
871898
return c.oidcTokenGenID
@@ -919,11 +946,3 @@ func (c *cancellListener) StopListening() bool {
919946
c.done <- struct{}{}
920947
return c.aborted
921948
}
922-
923-
func (c *connection) OIDCTokenGenID() uint64 {
924-
return c.oidcTokenGenID
925-
}
926-
927-
func (c *connection) SetOIDCTokenGenID(genID uint64) {
928-
c.oidcTokenGenID = genID
929-
}

x/mongo/driver/topology/connection_test.go

+17
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,23 @@ func TestConnection(t *testing.T) {
546546
}
547547
listener.assertCalledOnce(t)
548548
})
549+
t.Run("size too small errors", func(t *testing.T) {
550+
err := errors.New("malformed message length: 3")
551+
tnc := &testNetConn{readerr: err, buf: []byte{0x03, 0x00, 0x00, 0x00}}
552+
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
553+
listener := newTestCancellationListener(false)
554+
conn.cancellationListener = listener
555+
556+
want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: err.Error()}
557+
_, got := conn.readWireMessage(context.Background())
558+
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
559+
t.Errorf("errors do not match. got %v; want %v", got, want)
560+
}
561+
if !tnc.closed {
562+
t.Errorf("failed to closeConnection net.Conn after error writing bytes.")
563+
}
564+
listener.assertCalledOnce(t)
565+
})
549566
t.Run("full message read errors", func(t *testing.T) {
550567
err := errors.New("Read error")
551568
tnc := &testNetConn{readerr: err, buf: []byte{0x11, 0x00, 0x00, 0x00}}

0 commit comments

Comments
 (0)