@@ -9,6 +9,7 @@ package topology
9
9
import (
10
10
"context"
11
11
"crypto/tls"
12
+ "encoding/binary"
12
13
"errors"
13
14
"fmt"
14
15
"io"
@@ -18,6 +19,7 @@ import (
18
19
"sync/atomic"
19
20
"time"
20
21
22
+ "go.mongodb.org/mongo-driver/v2/internal/csot"
21
23
"go.mongodb.org/mongo-driver/v2/internal/driverutil"
22
24
"go.mongodb.org/mongo-driver/v2/mongo/address"
23
25
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
@@ -80,9 +82,9 @@ type connection struct {
80
82
// accessTokens in the OIDC authenticator cache.
81
83
oidcTokenGenID uint64
82
84
83
- // awaitingResponse indicates that the server response was not completely
85
+ // awaitRemainingBytes indicates the size of server response that was not completely
84
86
// read before returning the connection to the pool.
85
- awaitingResponse bool
87
+ awaitRemainingBytes * int32
86
88
}
87
89
88
90
// newConnection handles the creation of a connection. It does not connect the connection.
@@ -111,11 +113,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
111
113
return c
112
114
}
113
115
114
- // DriverConnectionID returns the driver connection ID.
115
- func (c * connection ) DriverConnectionID () int64 {
116
- return c .driverConnectionID
117
- }
118
-
119
116
// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
120
117
// configuration.
121
118
func (c * connection ) setGenerationNumber () {
@@ -137,6 +134,39 @@ func (c *connection) hasGenerationNumber() bool {
137
134
return driverutil .IsServerLoadBalanced (c .desc )
138
135
}
139
136
137
+ func configureTLS (ctx context.Context ,
138
+ tlsConnSource tlsConnectionSource ,
139
+ nc net.Conn ,
140
+ addr address.Address ,
141
+ config * tls.Config ,
142
+ ocspOpts * ocsp.VerifyOptions ,
143
+ ) (net.Conn , error ) {
144
+ // Ensure config.ServerName is always set for SNI.
145
+ if config .ServerName == "" {
146
+ hostname := addr .String ()
147
+ colonPos := strings .LastIndex (hostname , ":" )
148
+ if colonPos == - 1 {
149
+ colonPos = len (hostname )
150
+ }
151
+
152
+ hostname = hostname [:colonPos ]
153
+ config .ServerName = hostname
154
+ }
155
+
156
+ client := tlsConnSource .Client (nc , config )
157
+ if err := clientHandshake (ctx , client ); err != nil {
158
+ return nil , err
159
+ }
160
+
161
+ // Only do OCSP verification if TLS verification is requested.
162
+ if ! config .InsecureSkipVerify {
163
+ if ocspErr := ocsp .Verify (ctx , client .ConnectionState (), ocspOpts ); ocspErr != nil {
164
+ return nil , ocspErr
165
+ }
166
+ }
167
+ return client , nil
168
+ }
169
+
140
170
// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
141
171
// handshakes. All errors returned by connect are considered "before the handshake completes" and
142
172
// must be handled by calling the appropriate SDAM handshake error handler.
@@ -291,6 +321,10 @@ func (c *connection) closeConnectContext() {
291
321
}
292
322
}
293
323
324
+ func (c * connection ) cancellationListenerCallback () {
325
+ _ = c .close ()
326
+ }
327
+
294
328
func transformNetworkError (ctx context.Context , originalError error , contextDeadlineUsed bool ) error {
295
329
if originalError == nil {
296
330
return nil
@@ -313,10 +347,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
313
347
return originalError
314
348
}
315
349
316
- func (c * connection ) cancellationListenerCallback () {
317
- _ = c .close ()
318
- }
319
-
320
350
func (c * connection ) writeWireMessage (ctx context.Context , wm []byte ) error {
321
351
var err error
322
352
if atomic .LoadInt64 (& c .state ) != connConnected {
@@ -377,14 +407,10 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
377
407
378
408
dst , errMsg , err := c .read (ctx )
379
409
if err != nil {
380
- if nerr := net .Error (nil ); errors .As (err , & nerr ) && nerr .Timeout () {
381
- // If the error was a timeout error, instead of closing the
382
- // connection mark it as awaiting response so the pool can read the
383
- // response before making it available to other operations.
384
- c .awaitingResponse = true
385
- } else {
386
- // Otherwise, and close the connection because we don't know what
387
- // the connection state is.
410
+ if c .awaitRemainingBytes == nil {
411
+ // If the connection was not marked as awaiting response, use the
412
+ // pre-CSOT behavior and close the connection because we don't know
413
+ // if there are other bytes left to read.
388
414
c .close ()
389
415
}
390
416
message := errMsg
@@ -401,6 +427,26 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
401
427
return dst , nil
402
428
}
403
429
430
+ func (c * connection ) parseWmSizeBytes (wmSizeBytes [4 ]byte ) (int32 , error ) {
431
+ // read the length as an int32
432
+ size := int32 (binary .LittleEndian .Uint32 (wmSizeBytes [:]))
433
+
434
+ if size < 4 {
435
+ return 0 , fmt .Errorf ("malformed message length: %d" , size )
436
+ }
437
+ // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
438
+ // defaultMaxMessageSize instead.
439
+ maxMessageSize := c .desc .MaxMessageSize
440
+ if maxMessageSize == 0 {
441
+ maxMessageSize = defaultMaxMessageSize
442
+ }
443
+ if uint32 (size ) > maxMessageSize {
444
+ return 0 , errResponseTooLarge
445
+ }
446
+
447
+ return size , nil
448
+ }
449
+
404
450
func (c * connection ) read (ctx context.Context ) (bytesRead []byte , errMsg string , err error ) {
405
451
go c .cancellationListener .Listen (ctx , c .cancellationListenerCallback )
406
452
defer func () {
@@ -414,36 +460,43 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
414
460
}
415
461
}()
416
462
463
+ isCSOTTimeout := func (err error ) bool {
464
+ // If the error was a timeout error and CSOT is enabled, instead of
465
+ // closing the connection mark it as awaiting response so the pool
466
+ // can read the response before making it available to other
467
+ // operations.
468
+ nerr := net .Error (nil )
469
+ return errors .As (err , & nerr ) && nerr .Timeout () && csot .IsTimeoutContext (ctx )
470
+ }
471
+
417
472
// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
418
473
// reslice dst once instead of twice.
419
474
var sizeBuf [4 ]byte
420
475
421
476
// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
422
477
// because there might be more than one wire message waiting to be read, for example when
423
478
// reading messages from an exhaust cursor.
424
- _ , err = io .ReadFull (c .nc , sizeBuf [:])
479
+ n , err : = io .ReadFull (c .nc , sizeBuf [:])
425
480
if err != nil {
481
+ if l := int32 (n ); l == 0 && isCSOTTimeout (err ) {
482
+ c .awaitRemainingBytes = & l
483
+ }
426
484
return nil , "incomplete read of message header" , err
427
485
}
428
-
429
- // read the length as an int32
430
- size := (int32 (sizeBuf [0 ])) | (int32 (sizeBuf [1 ]) << 8 ) | (int32 (sizeBuf [2 ]) << 16 ) | (int32 (sizeBuf [3 ]) << 24 )
431
-
432
- // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
433
- // defaultMaxMessageSize instead.
434
- maxMessageSize := c .desc .MaxMessageSize
435
- if maxMessageSize == 0 {
436
- maxMessageSize = defaultMaxMessageSize
437
- }
438
- if uint32 (size ) > maxMessageSize {
439
- return nil , errResponseTooLarge .Error (), errResponseTooLarge
486
+ size , err := c .parseWmSizeBytes (sizeBuf )
487
+ if err != nil {
488
+ return nil , err .Error (), err
440
489
}
441
490
442
491
dst := make ([]byte , size )
443
492
copy (dst , sizeBuf [:])
444
493
445
- _ , err = io .ReadFull (c .nc , dst [4 :])
494
+ n , err = io .ReadFull (c .nc , dst [4 :])
446
495
if err != nil {
496
+ remainingBytes := size - 4 - int32 (n )
497
+ if remainingBytes > 0 && isCSOTTimeout (err ) {
498
+ c .awaitRemainingBytes = & remainingBytes
499
+ }
447
500
return dst , "incomplete read of full message" , err
448
501
}
449
502
@@ -496,10 +549,6 @@ func (c *connection) setCanStream(canStream bool) {
496
549
c .canStream = canStream
497
550
}
498
551
499
- func (c initConnection ) supportsStreaming () bool {
500
- return c .canStream
501
- }
502
-
503
552
func (c * connection ) setStreaming (streaming bool ) {
504
553
c .currentlyStreaming = streaming
505
554
}
@@ -508,6 +557,14 @@ func (c *connection) getCurrentlyStreaming() bool {
508
557
return c .currentlyStreaming
509
558
}
510
559
560
+ func (c * connection ) previousCanceled () bool {
561
+ if val := c .prevCanceled .Load (); val != nil {
562
+ return val .(bool )
563
+ }
564
+
565
+ return false
566
+ }
567
+
511
568
func (c * connection ) ID () string {
512
569
return c .id
513
570
}
@@ -516,12 +573,17 @@ func (c *connection) ServerConnectionID() *int64 {
516
573
return c .serverConnectionID
517
574
}
518
575
519
- func ( c * connection ) previousCanceled () bool {
520
- if val := c . prevCanceled . Load (); val != nil {
521
- return val .( bool )
522
- }
576
+ // DriverConnectionID returns the driver connection ID.
577
+ func ( c * connection ) DriverConnectionID () int64 {
578
+ return c . driverConnectionID
579
+ }
523
580
524
- return false
581
+ func (c * connection ) OIDCTokenGenID () uint64 {
582
+ return c .oidcTokenGenID
583
+ }
584
+
585
+ func (c * connection ) SetOIDCTokenGenID (genID uint64 ) {
586
+ c .oidcTokenGenID = genID
525
587
}
526
588
527
589
// initConnection is an adapter used during connection initialization. It has the minimum
@@ -562,7 +624,7 @@ func (c initConnection) CurrentlyStreaming() bool {
562
624
return c .getCurrentlyStreaming ()
563
625
}
564
626
func (c initConnection ) SupportsStreaming () bool {
565
- return c .supportsStreaming ()
627
+ return c .canStream
566
628
}
567
629
568
630
// Connection implements the driver.Connection interface to allow reading and writing wire
@@ -797,39 +859,6 @@ func (c *Connection) DriverConnectionID() int64 {
797
859
return c .connection .DriverConnectionID ()
798
860
}
799
861
800
- func configureTLS (ctx context.Context ,
801
- tlsConnSource tlsConnectionSource ,
802
- nc net.Conn ,
803
- addr address.Address ,
804
- config * tls.Config ,
805
- ocspOpts * ocsp.VerifyOptions ,
806
- ) (net.Conn , error ) {
807
- // Ensure config.ServerName is always set for SNI.
808
- if config .ServerName == "" {
809
- hostname := addr .String ()
810
- colonPos := strings .LastIndex (hostname , ":" )
811
- if colonPos == - 1 {
812
- colonPos = len (hostname )
813
- }
814
-
815
- hostname = hostname [:colonPos ]
816
- config .ServerName = hostname
817
- }
818
-
819
- client := tlsConnSource .Client (nc , config )
820
- if err := clientHandshake (ctx , client ); err != nil {
821
- return nil , err
822
- }
823
-
824
- // Only do OCSP verification if TLS verification is requested.
825
- if ! config .InsecureSkipVerify {
826
- if ocspErr := ocsp .Verify (ctx , client .ConnectionState (), ocspOpts ); ocspErr != nil {
827
- return nil , ocspErr
828
- }
829
- }
830
- return client , nil
831
- }
832
-
833
862
// OIDCTokenGenID returns the OIDC token generation ID.
834
863
func (c * Connection ) OIDCTokenGenID () uint64 {
835
864
return c .oidcTokenGenID
@@ -839,11 +868,3 @@ func (c *Connection) OIDCTokenGenID() uint64 {
839
868
func (c * Connection ) SetOIDCTokenGenID (genID uint64 ) {
840
869
c .oidcTokenGenID = genID
841
870
}
842
-
843
- func (c * connection ) OIDCTokenGenID () uint64 {
844
- return c .oidcTokenGenID
845
- }
846
-
847
- func (c * connection ) SetOIDCTokenGenID (genID uint64 ) {
848
- c .oidcTokenGenID = genID
849
- }
0 commit comments