@@ -9,6 +9,7 @@ package topology
9
9
import (
10
10
"context"
11
11
"crypto/tls"
12
+ "encoding/binary"
12
13
"errors"
13
14
"fmt"
14
15
"io"
@@ -79,9 +80,9 @@ type connection struct {
79
80
driverConnectionID uint64
80
81
generation uint64
81
82
82
- // awaitingResponse indicates that the server response was not completely
83
+ // awaitRemainingBytes indicates the size of server response that was not completely
83
84
// read before returning the connection to the pool.
84
- awaitingResponse bool
85
+ awaitRemainingBytes * int32
85
86
86
87
// oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate
87
88
// accessTokens in the OIDC authenticator cache.
@@ -115,12 +116,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
115
116
return c
116
117
}
117
118
118
- // DriverConnectionID returns the driver connection ID.
119
- // TODO(GODRIVER-2824): change return type to int64.
120
- func (c * connection ) DriverConnectionID () uint64 {
121
- return c .driverConnectionID
122
- }
123
-
124
119
// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
125
120
// configuration.
126
121
func (c * connection ) setGenerationNumber () {
@@ -142,6 +137,39 @@ func (c *connection) hasGenerationNumber() bool {
142
137
return c .desc .LoadBalanced ()
143
138
}
144
139
140
+ func configureTLS (ctx context.Context ,
141
+ tlsConnSource tlsConnectionSource ,
142
+ nc net.Conn ,
143
+ addr address.Address ,
144
+ config * tls.Config ,
145
+ ocspOpts * ocsp.VerifyOptions ,
146
+ ) (net.Conn , error ) {
147
+ // Ensure config.ServerName is always set for SNI.
148
+ if config .ServerName == "" {
149
+ hostname := addr .String ()
150
+ colonPos := strings .LastIndex (hostname , ":" )
151
+ if colonPos == - 1 {
152
+ colonPos = len (hostname )
153
+ }
154
+
155
+ hostname = hostname [:colonPos ]
156
+ config .ServerName = hostname
157
+ }
158
+
159
+ client := tlsConnSource .Client (nc , config )
160
+ if err := clientHandshake (ctx , client ); err != nil {
161
+ return nil , err
162
+ }
163
+
164
+ // Only do OCSP verification if TLS verification is requested.
165
+ if ! config .InsecureSkipVerify {
166
+ if ocspErr := ocsp .Verify (ctx , client .ConnectionState (), ocspOpts ); ocspErr != nil {
167
+ return nil , ocspErr
168
+ }
169
+ }
170
+ return client , nil
171
+ }
172
+
145
173
// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
146
174
// handshakes. All errors returned by connect are considered "before the handshake completes" and
147
175
// must be handled by calling the appropriate SDAM handshake error handler.
@@ -317,6 +345,10 @@ func (c *connection) closeConnectContext() {
317
345
}
318
346
}
319
347
348
+ func (c * connection ) cancellationListenerCallback () {
349
+ _ = c .close ()
350
+ }
351
+
320
352
func transformNetworkError (ctx context.Context , originalError error , contextDeadlineUsed bool ) error {
321
353
if originalError == nil {
322
354
return nil
@@ -339,10 +371,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
339
371
return originalError
340
372
}
341
373
342
- func (c * connection ) cancellationListenerCallback () {
343
- _ = c .close ()
344
- }
345
-
346
374
func (c * connection ) writeWireMessage (ctx context.Context , wm []byte ) error {
347
375
var err error
348
376
if atomic .LoadInt64 (& c .state ) != connConnected {
@@ -423,15 +451,10 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
423
451
424
452
dst , errMsg , err := c .read (ctx )
425
453
if err != nil {
426
- if nerr := net .Error (nil ); errors .As (err , & nerr ) && nerr .Timeout () && csot .IsTimeoutContext (ctx ) {
427
- // If the error was a timeout error and CSOT is enabled, instead of
428
- // closing the connection mark it as awaiting response so the pool
429
- // can read the response before making it available to other
430
- // operations.
431
- c .awaitingResponse = true
432
- } else {
433
- // Otherwise, use the pre-CSOT behavior and close the connection
434
- // because we don't know if there are other bytes left to read.
454
+ if c .awaitRemainingBytes == nil {
455
+ // If the connection was not marked as awaiting response, use the
456
+ // pre-CSOT behavior and close the connection because we don't know
457
+ // if there are other bytes left to read.
435
458
c .close ()
436
459
}
437
460
message := errMsg
@@ -448,6 +471,26 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
448
471
return dst , nil
449
472
}
450
473
474
+ func (c * connection ) parseWmSizeBytes (wmSizeBytes [4 ]byte ) (int32 , error ) {
475
+ // read the length as an int32
476
+ size := int32 (binary .LittleEndian .Uint32 (wmSizeBytes [:]))
477
+
478
+ if size < 4 {
479
+ return 0 , fmt .Errorf ("malformed message length: %d" , size )
480
+ }
481
+ // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
482
+ // defaultMaxMessageSize instead.
483
+ maxMessageSize := c .desc .MaxMessageSize
484
+ if maxMessageSize == 0 {
485
+ maxMessageSize = defaultMaxMessageSize
486
+ }
487
+ if uint32 (size ) > maxMessageSize {
488
+ return 0 , errResponseTooLarge
489
+ }
490
+
491
+ return size , nil
492
+ }
493
+
451
494
func (c * connection ) read (ctx context.Context ) (bytesRead []byte , errMsg string , err error ) {
452
495
go c .cancellationListener .Listen (ctx , c .cancellationListenerCallback )
453
496
defer func () {
@@ -461,36 +504,43 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
461
504
}
462
505
}()
463
506
507
+ isCSOTTimeout := func (err error ) bool {
508
+ // If the error was a timeout error and CSOT is enabled, instead of
509
+ // closing the connection mark it as awaiting response so the pool
510
+ // can read the response before making it available to other
511
+ // operations.
512
+ nerr := net .Error (nil )
513
+ return errors .As (err , & nerr ) && nerr .Timeout () && csot .IsTimeoutContext (ctx )
514
+ }
515
+
464
516
// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
465
517
// reslice dst once instead of twice.
466
518
var sizeBuf [4 ]byte
467
519
468
520
// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
469
521
// because there might be more than one wire message waiting to be read, for example when
470
522
// reading messages from an exhaust cursor.
471
- _ , err = io .ReadFull (c .nc , sizeBuf [:])
523
+ n , err : = io .ReadFull (c .nc , sizeBuf [:])
472
524
if err != nil {
525
+ if l := int32 (n ); l == 0 && isCSOTTimeout (err ) {
526
+ c .awaitRemainingBytes = & l
527
+ }
473
528
return nil , "incomplete read of message header" , err
474
529
}
475
-
476
- // read the length as an int32
477
- size := (int32 (sizeBuf [0 ])) | (int32 (sizeBuf [1 ]) << 8 ) | (int32 (sizeBuf [2 ]) << 16 ) | (int32 (sizeBuf [3 ]) << 24 )
478
-
479
- // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
480
- // defaultMaxMessageSize instead.
481
- maxMessageSize := c .desc .MaxMessageSize
482
- if maxMessageSize == 0 {
483
- maxMessageSize = defaultMaxMessageSize
484
- }
485
- if uint32 (size ) > maxMessageSize {
486
- return nil , errResponseTooLarge .Error (), errResponseTooLarge
530
+ size , err := c .parseWmSizeBytes (sizeBuf )
531
+ if err != nil {
532
+ return nil , err .Error (), err
487
533
}
488
534
489
535
dst := make ([]byte , size )
490
536
copy (dst , sizeBuf [:])
491
537
492
- _ , err = io .ReadFull (c .nc , dst [4 :])
538
+ n , err = io .ReadFull (c .nc , dst [4 :])
493
539
if err != nil {
540
+ remainingBytes := size - 4 - int32 (n )
541
+ if remainingBytes > 0 && isCSOTTimeout (err ) {
542
+ c .awaitRemainingBytes = & remainingBytes
543
+ }
494
544
return dst , "incomplete read of full message" , err
495
545
}
496
546
@@ -537,10 +587,6 @@ func (c *connection) setCanStream(canStream bool) {
537
587
c .canStream = canStream
538
588
}
539
589
540
- func (c initConnection ) supportsStreaming () bool {
541
- return c .canStream
542
- }
543
-
544
590
func (c * connection ) setStreaming (streaming bool ) {
545
591
c .currentlyStreaming = streaming
546
592
}
@@ -554,6 +600,12 @@ func (c *connection) setSocketTimeout(timeout time.Duration) {
554
600
c .writeTimeout = timeout
555
601
}
556
602
603
+ // DriverConnectionID returns the driver connection ID.
604
+ // TODO(GODRIVER-2824): change return type to int64.
605
+ func (c * connection ) DriverConnectionID () uint64 {
606
+ return c .driverConnectionID
607
+ }
608
+
557
609
func (c * connection ) ID () string {
558
610
return c .id
559
611
}
@@ -562,6 +614,14 @@ func (c *connection) ServerConnectionID() *int64 {
562
614
return c .serverConnectionID
563
615
}
564
616
617
+ func (c * connection ) OIDCTokenGenID () uint64 {
618
+ return c .oidcTokenGenID
619
+ }
620
+
621
+ func (c * connection ) SetOIDCTokenGenID (genID uint64 ) {
622
+ c .oidcTokenGenID = genID
623
+ }
624
+
565
625
// initConnection is an adapter used during connection initialization. It has the minimum
566
626
// functionality necessary to implement the driver.Connection interface, which is required to pass a
567
627
// *connection to a Handshaker.
@@ -599,7 +659,7 @@ func (c initConnection) CurrentlyStreaming() bool {
599
659
return c .getCurrentlyStreaming ()
600
660
}
601
661
func (c initConnection ) SupportsStreaming () bool {
602
- return c .supportsStreaming ()
662
+ return c .canStream
603
663
}
604
664
605
665
// Connection implements the driver.Connection interface to allow reading and writing wire
@@ -833,39 +893,6 @@ func (c *Connection) DriverConnectionID() uint64 {
833
893
return c .connection .DriverConnectionID ()
834
894
}
835
895
836
- func configureTLS (ctx context.Context ,
837
- tlsConnSource tlsConnectionSource ,
838
- nc net.Conn ,
839
- addr address.Address ,
840
- config * tls.Config ,
841
- ocspOpts * ocsp.VerifyOptions ,
842
- ) (net.Conn , error ) {
843
- // Ensure config.ServerName is always set for SNI.
844
- if config .ServerName == "" {
845
- hostname := addr .String ()
846
- colonPos := strings .LastIndex (hostname , ":" )
847
- if colonPos == - 1 {
848
- colonPos = len (hostname )
849
- }
850
-
851
- hostname = hostname [:colonPos ]
852
- config .ServerName = hostname
853
- }
854
-
855
- client := tlsConnSource .Client (nc , config )
856
- if err := clientHandshake (ctx , client ); err != nil {
857
- return nil , err
858
- }
859
-
860
- // Only do OCSP verification if TLS verification is requested.
861
- if ! config .InsecureSkipVerify {
862
- if ocspErr := ocsp .Verify (ctx , client .ConnectionState (), ocspOpts ); ocspErr != nil {
863
- return nil , ocspErr
864
- }
865
- }
866
- return client , nil
867
- }
868
-
869
896
// OIDCTokenGenID returns the OIDC token generation ID.
870
897
func (c * Connection ) OIDCTokenGenID () uint64 {
871
898
return c .oidcTokenGenID
@@ -919,11 +946,3 @@ func (c *cancellListener) StopListening() bool {
919
946
c .done <- struct {}{}
920
947
return c .aborted
921
948
}
922
-
923
- func (c * connection ) OIDCTokenGenID () uint64 {
924
- return c .oidcTokenGenID
925
- }
926
-
927
- func (c * connection ) SetOIDCTokenGenID (genID uint64 ) {
928
- c .oidcTokenGenID = genID
929
- }
0 commit comments