Skip to content

Commit 3ce2991

Browse files
committed
connector: Rebase on top of master
1 parent 72d9da9 commit 3ce2991

8 files changed

+109
-29
lines changed

appengine.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,14 @@
1111
package mysql
1212

1313
import (
14+
"context"
15+
1416
"google.golang.org/appengine/cloudsql"
1517
)
1618

1719
func init() {
18-
RegisterDial("cloudsql", cloudsql.Dial)
20+
RegisterDialContext("cloudsql", func(_ context.Context, instance addr) (net.Conn, error) {
21+
// XXX: the cloudsql driver still does not export a Context-aware dialer.
22+
return cloudsql.Dial(instance)
23+
})
1924
}

connector.go

+9-5
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,21 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
3333
mc.parseTime = mc.cfg.ParseTime
3434

3535
// Connect to Server
36-
// TODO: needs RegisterDialContext
3736
dialsLock.RLock()
3837
dial, ok := dials[mc.cfg.Net]
3938
dialsLock.RUnlock()
4039
if ok {
41-
mc.netConn, err = dial(mc.cfg.Addr)
40+
mc.netConn, err = dial(ctx, mc.cfg.Addr)
4241
} else {
4342
nd := net.Dialer{Timeout: mc.cfg.Timeout}
4443
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
4544
}
45+
4646
if err != nil {
47+
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
48+
errLog.Print("net.Error from Dial()': ", nerr.Error())
49+
return nil, driver.ErrBadConn
50+
}
4751
return nil, err
4852
}
4953

@@ -82,18 +86,18 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
8286
}
8387

8488
// Send Client Authentication Packet
85-
authResp, addNUL, err := mc.auth(authData, plugin)
89+
authResp, err := mc.auth(authData, plugin)
8690
if err != nil {
8791
// try the default auth plugin, if using the requested plugin failed
8892
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
8993
plugin = defaultAuthPlugin
90-
authResp, addNUL, err = mc.auth(authData, plugin)
94+
authResp, err = mc.auth(authData, plugin)
9195
if err != nil {
9296
mc.cleanup()
9397
return nil, err
9498
}
9599
}
96-
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
100+
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
97101
mc.cleanup()
98102
return nil, err
99103
}

driver.go

+22-5
Original file line numberDiff line numberDiff line change
@@ -30,25 +30,42 @@ type MySQLDriver struct{}
3030

3131
// DialFunc is a function which can be used to establish the network connection.
3232
// Custom dial functions must be registered with RegisterDial
33+
//
34+
// Deprecated: users should register a DialContextFunc instead
3335
type DialFunc func(addr string) (net.Conn, error)
3436

37+
// DialContextFunc is a function which can be used to establish the network connection.
38+
// Custom dial functions must be registered with RegisterDialContext
39+
type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error)
40+
3541
var (
3642
dialsLock sync.RWMutex
37-
dials map[string]DialFunc
43+
dials map[string]DialContextFunc
3844
)
3945

40-
// RegisterDial registers a custom dial function. It can then be used by the
46+
// RegisterDialContext registers a custom dial function. It can then be used by the
4147
// network address mynet(addr), where mynet is the registered new network.
42-
// addr is passed as a parameter to the dial function.
43-
func RegisterDial(net string, dial DialFunc) {
48+
// The current context for the connection and its address is passed to the dial function.
49+
func RegisterDialContext(net string, dial DialContextFunc) {
4450
dialsLock.Lock()
4551
defer dialsLock.Unlock()
4652
if dials == nil {
47-
dials = make(map[string]DialFunc)
53+
dials = make(map[string]DialContextFunc)
4854
}
4955
dials[net] = dial
5056
}
5157

58+
// RegisterDial registers a custom dial function. It can then be used by the
59+
// network address mynet(addr), where mynet is the registered new network.
60+
// addr is passed as a parameter to the dial function.
61+
//
62+
// Deprecated: users should call RegisterDialContext instead
63+
func RegisterDial(network string, dial DialFunc) {
64+
RegisterDialContext(network, func(_ context.Context, addr string) (net.Conn, error) {
65+
return dial(addr)
66+
})
67+
}
68+
5269
// Open new Connection.
5370
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
5471
// the DSN string is formatted

driver_go1.10.go renamed to driver_go110.go

+7-15
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,18 @@
1111
package mysql
1212

1313
import (
14-
"crypto/rsa"
1514
"database/sql/driver"
16-
"math/big"
1715
)
1816

1917
// NewConnector returns new driver.Connector.
20-
func NewConnector(cfg *Config) driver.Connector {
21-
copyCfg := *cfg
22-
copyCfg.tls = cfg.tls.Clone()
23-
copyCfg.Params = make(map[string]string, len(cfg.Params))
24-
for k, v := range cfg.Params {
25-
copyCfg.Params[k] = v
26-
}
27-
if cfg.pubKey != nil {
28-
copyCfg.pubKey = &rsa.PublicKey{
29-
N: new(big.Int).Set(cfg.pubKey.N),
30-
E: cfg.pubKey.E,
31-
}
18+
func NewConnector(cfg *Config) (driver.Connector, error) {
19+
cfg = cfg.Clone()
20+
// normalize the contents of cfg so calls to NewConnector have the same
21+
// behavior as MySQLDriver.OpenConnector
22+
if err := cfg.normalize(); err != nil {
23+
return nil, err
3224
}
33-
return &connector{cfg: &copyCfg}
25+
return &connector{cfg: cfg}, nil
3426
}
3527

3628
// OpenConnector implements driver.DriverContext.
File renamed without changes.

driver_test.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -1846,7 +1846,7 @@ func TestConcurrent(t *testing.T) {
18461846
}
18471847

18481848
func testDialError(t *testing.T, dialErr error, expectErr error) {
1849-
RegisterDial("mydial", func(addr string) (net.Conn, error) {
1849+
RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) {
18501850
return nil, dialErr
18511851
})
18521852

@@ -1884,8 +1884,9 @@ func TestCustomDial(t *testing.T) {
18841884
}
18851885

18861886
// our custom dial function which justs wraps net.Dial here
1887-
RegisterDial("mydial", func(addr string) (net.Conn, error) {
1888-
return net.Dial(prot, addr)
1887+
RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) {
1888+
var d net.Dialer
1889+
return d.DialContext(ctx, prot, addr)
18891890
})
18901891

18911892
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))

dsn.go

+21
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"crypto/tls"
1515
"errors"
1616
"fmt"
17+
"math/big"
1718
"net"
1819
"net/url"
1920
"sort"
@@ -72,6 +73,26 @@ func NewConfig() *Config {
7273
}
7374
}
7475

76+
func (cfg *Config) Clone() *Config {
77+
cp := *cfg
78+
if cp.tls != nil {
79+
cp.tls = cfg.tls.Clone()
80+
}
81+
if len(cp.Params) > 0 {
82+
cp.Params = make(map[string]string, len(cfg.Params))
83+
for k, v := range cfg.Params {
84+
cp.Params[k] = v
85+
}
86+
}
87+
if cfg.pubKey != nil {
88+
cp.pubKey = &rsa.PublicKey{
89+
N: new(big.Int).Set(cfg.pubKey.N),
90+
E: cfg.pubKey.E,
91+
}
92+
}
93+
return &cp
94+
}
95+
7596
func (cfg *Config) normalize() error {
7697
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
7798
return errInvalidDSNUnsafeCollation

dsn_test.go

+40
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,46 @@ func TestParamsAreSorted(t *testing.T) {
318318
}
319319
}
320320

321+
func TestCloneConfig(t *testing.T) {
322+
RegisterServerPubKey("testKey", testPubKeyRSA)
323+
defer DeregisterServerPubKey("testKey")
324+
325+
expectedServerName := "example.com"
326+
dsn := "tcp(example.com:1234)/?tls=true&foobar=baz&serverPubKey=testKey"
327+
cfg, err := ParseDSN(dsn)
328+
if err != nil {
329+
t.Fatal(err.Error())
330+
}
331+
332+
cfg2 := cfg.Clone()
333+
if cfg == cfg2 {
334+
t.Errorf("Config.Clone did not create a separate config struct")
335+
}
336+
337+
if cfg2.tls.ServerName != expectedServerName {
338+
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
339+
}
340+
341+
cfg2.tls.ServerName = "example2.com"
342+
if cfg.tls.ServerName == cfg2.tls.ServerName {
343+
t.Errorf("changed cfg.tls.Server name should not propagate to original Config")
344+
}
345+
346+
if _, ok := cfg2.Params["foobar"]; !ok {
347+
t.Errorf("cloned Config is missing custom params")
348+
}
349+
350+
delete(cfg2.Params, "foobar")
351+
352+
if _, ok := cfg.Params["foobar"]; !ok {
353+
t.Errorf("custom params in cloned Config should not propagate to original Config")
354+
}
355+
356+
if !reflect.DeepEqual(cfg.pubKey, cfg2.pubKey) {
357+
t.Errorf("public key in Config should be identical")
358+
}
359+
}
360+
321361
func BenchmarkParseDSN(b *testing.B) {
322362
b.ReportAllocs()
323363

0 commit comments

Comments
 (0)