Skip to content

Commit 23f77c3

Browse files
committed
update the connection pool background read logic
1 parent be66ac6 commit 23f77c3

File tree

4 files changed

+268
-45
lines changed

4 files changed

+268
-45
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ type connection struct {
7979
driverConnectionID uint64
8080
generation uint64
8181

82-
// awaitingResponse indicates that the server response was not completely
82+
// awaitingResponse indicates the size of server response that was not completely
8383
// read before returning the connection to the pool.
84-
awaitingResponse bool
84+
awaitingResponse *int32
8585

8686
// oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate
8787
// accessTokens in the OIDC authenticator cache.
@@ -423,15 +423,10 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
423423

424424
dst, errMsg, err := c.read(ctx)
425425
if err != nil {
426-
if nerr := net.Error(nil); errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx) {
427-
// If the error was a timeout error and CSOT is enabled, instead of
428-
// closing the connection mark it as awaiting response so the pool
429-
// can read the response before making it available to other
430-
// operations.
431-
c.awaitingResponse = true
432-
} else {
433-
// Otherwise, use the pre-CSOT behavior and close the connection
434-
// because we don't know if there are other bytes left to read.
426+
if c.awaitingResponse == nil {
427+
// If the connection was not marked as awaiting response, use the
428+
// pre-CSOT behavior and close the connection because we don't know
429+
// if there are other bytes left to read.
435430
c.close()
436431
}
437432
message := errMsg
@@ -461,15 +456,27 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
461456
}
462457
}()
463458

459+
needToWait := func(err error) bool {
460+
// If the error was a timeout error and CSOT is enabled, instead of
461+
// closing the connection mark it as awaiting response so the pool
462+
// can read the response before making it available to other
463+
// operations.
464+
nerr := net.Error(nil)
465+
return errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx)
466+
}
467+
464468
// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
465469
// reslice dst once instead of twice.
466470
var sizeBuf [4]byte
467471

468472
// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
469473
// because there might be more than one wire message waiting to be read, for example when
470474
// reading messages from an exhaust cursor.
471-
_, err = io.ReadFull(c.nc, sizeBuf[:])
475+
n, err := io.ReadFull(c.nc, sizeBuf[:])
472476
if err != nil {
477+
if l := int32(n); l == 0 && needToWait(err) {
478+
c.awaitingResponse = &l
479+
}
473480
return nil, "incomplete read of message header", err
474481
}
475482

@@ -493,8 +500,11 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
493500
dst := make([]byte, size)
494501
copy(dst, sizeBuf[:])
495502

496-
_, err = io.ReadFull(c.nc, dst[4:])
503+
n, err = io.ReadFull(c.nc, dst[4:])
497504
if err != nil {
505+
if l := size - 4 - int32(n); l > 0 && needToWait(err) {
506+
c.awaitingResponse = &l
507+
}
498508
return dst, "incomplete read of full message", err
499509
}
500510

x/mongo/driver/topology/connection_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ func TestConnection(t *testing.T) {
547547
listener.assertCalledOnce(t)
548548
})
549549
t.Run("size too small errors", func(t *testing.T) {
550-
err := errors.New("malformatted message length: 3")
550+
err := errors.New("malformed message length: 3")
551551
tnc := &testNetConn{readerr: err, buf: []byte{0x03, 0x00, 0x00, 0x00}}
552552
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
553553
listener := newTestCancellationListener(false)

x/mongo/driver/topology/pool.go

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ package topology
99
import (
1010
"context"
1111
"fmt"
12+
"io"
13+
"io/ioutil"
1214
"net"
1315
"sync"
1416
"sync/atomic"
@@ -788,17 +790,27 @@ var (
788790
//
789791
// It calls the package-global BGReadCallback function, if set, with the
790792
// address, timings, and any errors that occurred.
791-
func bgRead(pool *pool, conn *connection) {
792-
var start, read time.Time
793-
start = time.Now()
794-
errs := make([]error, 0)
795-
connClosed := false
793+
func bgRead(pool *pool, conn *connection, size int32) {
794+
var err error
795+
start := time.Now()
796796

797797
defer func() {
798+
read := time.Now()
799+
errs := make([]error, 0)
800+
connClosed := false
801+
if err != nil {
802+
errs = append(errs, err)
803+
connClosed = true
804+
err = conn.close()
805+
if err != nil {
806+
errs = append(errs, fmt.Errorf("error closing conn after reading: %w", err))
807+
}
808+
}
809+
798810
// No matter what happens, always check the connection back into the
799811
// pool, which will either make it available for other operations or
800812
// remove it from the pool if it was closed.
801-
err := pool.checkInNoEvent(conn)
813+
err = pool.checkInNoEvent(conn)
802814
if err != nil {
803815
errs = append(errs, fmt.Errorf("error checking in: %w", err))
804816
}
@@ -808,34 +820,37 @@ func bgRead(pool *pool, conn *connection) {
808820
}
809821
}()
810822

811-
err := conn.nc.SetReadDeadline(time.Now().Add(BGReadTimeout))
823+
err = conn.nc.SetReadDeadline(time.Now().Add(BGReadTimeout))
812824
if err != nil {
813-
errs = append(errs, fmt.Errorf("error setting a read deadline: %w", err))
814-
815-
connClosed = true
816-
err := conn.close()
817-
if err != nil {
818-
errs = append(errs, fmt.Errorf("error closing conn after setting read deadline: %w", err))
819-
}
820-
825+
err = fmt.Errorf("error setting a read deadline: %w", err)
821826
return
822827
}
823828

824-
// The context here is only used for cancellation, not deadline timeout, so
825-
// use context.Background(). The read timeout is set by calling
826-
// SetReadDeadline above.
827-
_, _, err = conn.read(context.Background())
828-
read = time.Now()
829-
if err != nil {
830-
errs = append(errs, fmt.Errorf("error reading: %w", err))
831-
832-
connClosed = true
833-
err := conn.close()
829+
if size == 0 {
830+
var sizeBuf [4]byte
831+
_, err = io.ReadFull(conn.nc, sizeBuf[:])
834832
if err != nil {
835-
errs = append(errs, fmt.Errorf("error closing conn after reading: %w", err))
833+
err = fmt.Errorf("error reading the message size: %w", err)
834+
return
836835
}
837-
838-
return
836+
size = (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)
837+
if size < 4 {
838+
err = fmt.Errorf("malformed message length: %d", size)
839+
return
840+
}
841+
maxMessageSize := conn.desc.MaxMessageSize
842+
if maxMessageSize == 0 {
843+
maxMessageSize = defaultMaxMessageSize
844+
}
845+
if uint32(size) > maxMessageSize {
846+
err = errResponseTooLarge
847+
return
848+
}
849+
size -= 4
850+
}
851+
_, err = io.CopyN(ioutil.Discard, conn.nc, int64(size))
852+
if err != nil {
853+
err = fmt.Errorf("error reading message of %d: %w", size, err)
839854
}
840855
}
841856

@@ -886,9 +901,9 @@ func (p *pool) checkInNoEvent(conn *connection) error {
886901
// means that connections in "awaiting response" state are checked in but
887902
// not usable, which is not covered by the current pool events. We may need
888903
// to add pool event information in the future to communicate that.
889-
if conn.awaitingResponse {
890-
conn.awaitingResponse = false
891-
go bgRead(p, conn)
904+
if conn.awaitingResponse != nil {
905+
go bgRead(p, conn, *conn.awaitingResponse)
906+
conn.awaitingResponse = nil
892907
return nil
893908
}
894909

0 commit comments

Comments
 (0)