@@ -9,6 +9,7 @@ package topology
9
9
import (
10
10
"context"
11
11
"crypto/tls"
12
+ "encoding/binary"
12
13
"errors"
13
14
"fmt"
14
15
"io"
@@ -80,9 +81,9 @@ type connection struct {
80
81
// accessTokens in the OIDC authenticator cache.
81
82
oidcTokenGenID uint64
82
83
83
- // awaitingResponse indicates that the server response was not completely
84
+ // awaitRemainingBytes indicates the size of server response that was not completely
84
85
// read before returning the connection to the pool.
85
- awaitingResponse bool
86
+ awaitRemainingBytes * int32
86
87
}
87
88
88
89
// newConnection handles the creation of a connection. It does not connect the connection.
@@ -111,11 +112,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
111
112
return c
112
113
}
113
114
114
- // DriverConnectionID returns the driver connection ID.
115
- func (c * connection ) DriverConnectionID () int64 {
116
- return c .driverConnectionID
117
- }
118
-
119
115
// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
120
116
// configuration.
121
117
func (c * connection ) setGenerationNumber () {
@@ -137,6 +133,39 @@ func (c *connection) hasGenerationNumber() bool {
137
133
return driverutil .IsServerLoadBalanced (c .desc )
138
134
}
139
135
136
+ func configureTLS (ctx context.Context ,
137
+ tlsConnSource tlsConnectionSource ,
138
+ nc net.Conn ,
139
+ addr address.Address ,
140
+ config * tls.Config ,
141
+ ocspOpts * ocsp.VerifyOptions ,
142
+ ) (net.Conn , error ) {
143
+ // Ensure config.ServerName is always set for SNI.
144
+ if config .ServerName == "" {
145
+ hostname := addr .String ()
146
+ colonPos := strings .LastIndex (hostname , ":" )
147
+ if colonPos == - 1 {
148
+ colonPos = len (hostname )
149
+ }
150
+
151
+ hostname = hostname [:colonPos ]
152
+ config .ServerName = hostname
153
+ }
154
+
155
+ client := tlsConnSource .Client (nc , config )
156
+ if err := clientHandshake (ctx , client ); err != nil {
157
+ return nil , err
158
+ }
159
+
160
+ // Only do OCSP verification if TLS verification is requested.
161
+ if ! config .InsecureSkipVerify {
162
+ if ocspErr := ocsp .Verify (ctx , client .ConnectionState (), ocspOpts ); ocspErr != nil {
163
+ return nil , ocspErr
164
+ }
165
+ }
166
+ return client , nil
167
+ }
168
+
140
169
// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
141
170
// handshakes. All errors returned by connect are considered "before the handshake completes" and
142
171
// must be handled by calling the appropriate SDAM handshake error handler.
@@ -291,6 +320,10 @@ func (c *connection) closeConnectContext() {
291
320
}
292
321
}
293
322
323
+ func (c * connection ) cancellationListenerCallback () {
324
+ _ = c .close ()
325
+ }
326
+
294
327
func transformNetworkError (ctx context.Context , originalError error , contextDeadlineUsed bool ) error {
295
328
if originalError == nil {
296
329
return nil
@@ -313,10 +346,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
313
346
return originalError
314
347
}
315
348
316
- func (c * connection ) cancellationListenerCallback () {
317
- _ = c .close ()
318
- }
319
-
320
349
func (c * connection ) writeWireMessage (ctx context.Context , wm []byte ) error {
321
350
var err error
322
351
if atomic .LoadInt64 (& c .state ) != connConnected {
@@ -377,14 +406,9 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
377
406
378
407
dst , errMsg , err := c .read (ctx )
379
408
if err != nil {
380
- if nerr := net .Error (nil ); errors .As (err , & nerr ) && nerr .Timeout () {
381
- // If the error was a timeout error, instead of closing the
382
- // connection mark it as awaiting response so the pool can read the
383
- // response before making it available to other operations.
384
- c .awaitingResponse = true
385
- } else {
386
- // Otherwise, and close the connection because we don't know what
387
- // the connection state is.
409
+ if c .awaitRemainingBytes == nil {
410
+ // If the connection was not marked as awaiting response, close the
411
+ // connection because we don't know what the connection state is.
388
412
c .close ()
389
413
}
390
414
message := errMsg
@@ -401,6 +425,26 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
401
425
return dst , nil
402
426
}
403
427
428
+ func (c * connection ) parseWmSizeBytes (wmSizeBytes [4 ]byte ) (int32 , error ) {
429
+ // read the length as an int32
430
+ size := int32 (binary .LittleEndian .Uint32 (wmSizeBytes [:]))
431
+
432
+ if size < 4 {
433
+ return 0 , fmt .Errorf ("malformed message length: %d" , size )
434
+ }
435
+ // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
436
+ // defaultMaxMessageSize instead.
437
+ maxMessageSize := c .desc .MaxMessageSize
438
+ if maxMessageSize == 0 {
439
+ maxMessageSize = defaultMaxMessageSize
440
+ }
441
+ if uint32 (size ) > maxMessageSize {
442
+ return 0 , errResponseTooLarge
443
+ }
444
+
445
+ return size , nil
446
+ }
447
+
404
448
func (c * connection ) read (ctx context.Context ) (bytesRead []byte , errMsg string , err error ) {
405
449
go c .cancellationListener .Listen (ctx , c .cancellationListenerCallback )
406
450
defer func () {
@@ -414,36 +458,42 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
414
458
}
415
459
}()
416
460
461
+ isCSOTTimeout := func (err error ) bool {
462
+ // If the error was a timeout error, instead of closing the
463
+ // connection mark it as awaiting response so the pool can read the
464
+ // response before making it available to other operations.
465
+ nerr := net .Error (nil )
466
+ return errors .As (err , & nerr ) && nerr .Timeout ()
467
+ }
468
+
417
469
// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
418
470
// reslice dst once instead of twice.
419
471
var sizeBuf [4 ]byte
420
472
421
473
// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
422
474
// because there might be more than one wire message waiting to be read, for example when
423
475
// reading messages from an exhaust cursor.
424
- _ , err = io .ReadFull (c .nc , sizeBuf [:])
476
+ n , err : = io .ReadFull (c .nc , sizeBuf [:])
425
477
if err != nil {
478
+ if l := int32 (n ); l == 0 && isCSOTTimeout (err ) {
479
+ c .awaitRemainingBytes = & l
480
+ }
426
481
return nil , "incomplete read of message header" , err
427
482
}
428
-
429
- // read the length as an int32
430
- size := (int32 (sizeBuf [0 ])) | (int32 (sizeBuf [1 ]) << 8 ) | (int32 (sizeBuf [2 ]) << 16 ) | (int32 (sizeBuf [3 ]) << 24 )
431
-
432
- // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
433
- // defaultMaxMessageSize instead.
434
- maxMessageSize := c .desc .MaxMessageSize
435
- if maxMessageSize == 0 {
436
- maxMessageSize = defaultMaxMessageSize
437
- }
438
- if uint32 (size ) > maxMessageSize {
439
- return nil , errResponseTooLarge .Error (), errResponseTooLarge
483
+ size , err := c .parseWmSizeBytes (sizeBuf )
484
+ if err != nil {
485
+ return nil , err .Error (), err
440
486
}
441
487
442
488
dst := make ([]byte , size )
443
489
copy (dst , sizeBuf [:])
444
490
445
- _ , err = io .ReadFull (c .nc , dst [4 :])
491
+ n , err = io .ReadFull (c .nc , dst [4 :])
446
492
if err != nil {
493
+ remainingBytes := size - 4 - int32 (n )
494
+ if remainingBytes > 0 && isCSOTTimeout (err ) {
495
+ c .awaitRemainingBytes = & remainingBytes
496
+ }
447
497
return dst , "incomplete read of full message" , err
448
498
}
449
499
@@ -496,10 +546,6 @@ func (c *connection) setCanStream(canStream bool) {
496
546
c .canStream = canStream
497
547
}
498
548
499
- func (c initConnection ) supportsStreaming () bool {
500
- return c .canStream
501
- }
502
-
503
549
func (c * connection ) setStreaming (streaming bool ) {
504
550
c .currentlyStreaming = streaming
505
551
}
@@ -508,6 +554,14 @@ func (c *connection) getCurrentlyStreaming() bool {
508
554
return c .currentlyStreaming
509
555
}
510
556
557
+ func (c * connection ) previousCanceled () bool {
558
+ if val := c .prevCanceled .Load (); val != nil {
559
+ return val .(bool )
560
+ }
561
+
562
+ return false
563
+ }
564
+
511
565
func (c * connection ) ID () string {
512
566
return c .id
513
567
}
@@ -516,12 +570,17 @@ func (c *connection) ServerConnectionID() *int64 {
516
570
return c .serverConnectionID
517
571
}
518
572
519
- func ( c * connection ) previousCanceled () bool {
520
- if val := c . prevCanceled . Load (); val != nil {
521
- return val .( bool )
522
- }
573
+ // DriverConnectionID returns the driver connection ID.
574
+ func ( c * connection ) DriverConnectionID () int64 {
575
+ return c . driverConnectionID
576
+ }
523
577
524
- return false
578
+ func (c * connection ) OIDCTokenGenID () uint64 {
579
+ return c .oidcTokenGenID
580
+ }
581
+
582
+ func (c * connection ) SetOIDCTokenGenID (genID uint64 ) {
583
+ c .oidcTokenGenID = genID
525
584
}
526
585
527
586
// initConnection is an adapter used during connection initialization. It has the minimum
@@ -562,7 +621,7 @@ func (c initConnection) CurrentlyStreaming() bool {
562
621
return c .getCurrentlyStreaming ()
563
622
}
564
623
func (c initConnection ) SupportsStreaming () bool {
565
- return c .supportsStreaming ()
624
+ return c .canStream
566
625
}
567
626
568
627
// Connection implements the driver.Connection interface to allow reading and writing wire
@@ -797,39 +856,6 @@ func (c *Connection) DriverConnectionID() int64 {
797
856
return c .connection .DriverConnectionID ()
798
857
}
799
858
800
- func configureTLS (ctx context.Context ,
801
- tlsConnSource tlsConnectionSource ,
802
- nc net.Conn ,
803
- addr address.Address ,
804
- config * tls.Config ,
805
- ocspOpts * ocsp.VerifyOptions ,
806
- ) (net.Conn , error ) {
807
- // Ensure config.ServerName is always set for SNI.
808
- if config .ServerName == "" {
809
- hostname := addr .String ()
810
- colonPos := strings .LastIndex (hostname , ":" )
811
- if colonPos == - 1 {
812
- colonPos = len (hostname )
813
- }
814
-
815
- hostname = hostname [:colonPos ]
816
- config .ServerName = hostname
817
- }
818
-
819
- client := tlsConnSource .Client (nc , config )
820
- if err := clientHandshake (ctx , client ); err != nil {
821
- return nil , err
822
- }
823
-
824
- // Only do OCSP verification if TLS verification is requested.
825
- if ! config .InsecureSkipVerify {
826
- if ocspErr := ocsp .Verify (ctx , client .ConnectionState (), ocspOpts ); ocspErr != nil {
827
- return nil , ocspErr
828
- }
829
- }
830
- return client , nil
831
- }
832
-
833
859
// OIDCTokenGenID returns the OIDC token generation ID.
834
860
func (c * Connection ) OIDCTokenGenID () uint64 {
835
861
return c .oidcTokenGenID
@@ -839,11 +865,3 @@ func (c *Connection) OIDCTokenGenID() uint64 {
839
865
func (c * Connection ) SetOIDCTokenGenID (genID uint64 ) {
840
866
c .oidcTokenGenID = genID
841
867
}
842
-
843
- func (c * connection ) OIDCTokenGenID () uint64 {
844
- return c .oidcTokenGenID
845
- }
846
-
847
- func (c * connection ) SetOIDCTokenGenID (genID uint64 ) {
848
- c .oidcTokenGenID = genID
849
- }
0 commit comments