Skip to content

Commit 72d9da9

Browse files
shogo82148vmg
authored andcommitted
Implement Connector and DriverContext interface
1 parent df597a2 commit 72d9da9

File tree

4 files changed

+206
-106
lines changed

4 files changed

+206
-106
lines changed

connector.go

+139
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2+
//
3+
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
4+
//
5+
// This Source Code Form is subject to the terms of the Mozilla Public
6+
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7+
// You can obtain one at http://mozilla.org/MPL/2.0/.
8+
9+
package mysql
10+
11+
import (
12+
"context"
13+
"database/sql/driver"
14+
"net"
15+
)
16+
17+
type connector struct {
18+
cfg *Config // immutable private copy.
19+
}
20+
21+
// Connect implements driver.Connector interface.
22+
// Connect returns a connection to the database.
23+
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
24+
var err error
25+
26+
// New mysqlConn
27+
mc := &mysqlConn{
28+
maxAllowedPacket: maxPacketSize,
29+
maxWriteSize: maxPacketSize - 1,
30+
closech: make(chan struct{}),
31+
cfg: c.cfg,
32+
}
33+
mc.parseTime = mc.cfg.ParseTime
34+
35+
// Connect to Server
36+
// TODO: needs RegisterDialContext
37+
dialsLock.RLock()
38+
dial, ok := dials[mc.cfg.Net]
39+
dialsLock.RUnlock()
40+
if ok {
41+
mc.netConn, err = dial(mc.cfg.Addr)
42+
} else {
43+
nd := net.Dialer{Timeout: mc.cfg.Timeout}
44+
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
45+
}
46+
if err != nil {
47+
return nil, err
48+
}
49+
50+
// Enable TCP Keepalives on TCP connections
51+
if tc, ok := mc.netConn.(*net.TCPConn); ok {
52+
if err := tc.SetKeepAlive(true); err != nil {
53+
// Don't send COM_QUIT before handshake.
54+
mc.netConn.Close()
55+
mc.netConn = nil
56+
return nil, err
57+
}
58+
}
59+
60+
// Call startWatcher for context support (From Go 1.8)
61+
mc.startWatcher()
62+
if err := mc.watchCancel(ctx); err != nil {
63+
return nil, err
64+
}
65+
defer mc.finish()
66+
67+
mc.buf = newBuffer(mc.netConn)
68+
69+
// Set I/O timeouts
70+
mc.buf.timeout = mc.cfg.ReadTimeout
71+
mc.writeTimeout = mc.cfg.WriteTimeout
72+
73+
// Reading Handshake Initialization Packet
74+
authData, plugin, err := mc.readHandshakePacket()
75+
if err != nil {
76+
mc.cleanup()
77+
return nil, err
78+
}
79+
80+
if plugin == "" {
81+
plugin = defaultAuthPlugin
82+
}
83+
84+
// Send Client Authentication Packet
85+
authResp, addNUL, err := mc.auth(authData, plugin)
86+
if err != nil {
87+
// try the default auth plugin, if using the requested plugin failed
88+
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
89+
plugin = defaultAuthPlugin
90+
authResp, addNUL, err = mc.auth(authData, plugin)
91+
if err != nil {
92+
mc.cleanup()
93+
return nil, err
94+
}
95+
}
96+
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
97+
mc.cleanup()
98+
return nil, err
99+
}
100+
101+
// Handle response to auth packet, switch methods if possible
102+
if err = mc.handleAuthResult(authData, plugin); err != nil {
103+
// Authentication failed and MySQL has already closed the connection
104+
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
105+
// Do not send COM_QUIT, just cleanup and return the error.
106+
mc.cleanup()
107+
return nil, err
108+
}
109+
110+
if mc.cfg.MaxAllowedPacket > 0 {
111+
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
112+
} else {
113+
// Get max allowed packet size
114+
maxap, err := mc.getSystemVar("max_allowed_packet")
115+
if err != nil {
116+
mc.Close()
117+
return nil, err
118+
}
119+
mc.maxAllowedPacket = stringToInt(maxap) - 1
120+
}
121+
if mc.maxAllowedPacket < maxPacketSize {
122+
mc.maxWriteSize = mc.maxAllowedPacket
123+
}
124+
125+
// Handle DSN Params
126+
err = mc.handleParams()
127+
if err != nil {
128+
mc.Close()
129+
return nil, err
130+
}
131+
132+
return mc, nil
133+
}
134+
135+
// Driver implements driver.Connector interface.
136+
// Driver returns &MySQLDriver{}.
137+
func (c *connector) Driver() driver.Driver {
138+
return &MySQLDriver{}
139+
}

driver.go

+5-106
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package mysql
1818

1919
import (
20+
"context"
2021
"database/sql"
2122
"database/sql/driver"
2223
"net"
@@ -52,116 +53,14 @@ func RegisterDial(net string, dial DialFunc) {
5253
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
5354
// the DSN string is formatted
5455
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
55-
var err error
56-
57-
// New mysqlConn
58-
mc := &mysqlConn{
59-
maxAllowedPacket: maxPacketSize,
60-
maxWriteSize: maxPacketSize - 1,
61-
closech: make(chan struct{}),
62-
}
63-
mc.cfg, err = ParseDSN(dsn)
64-
if err != nil {
65-
return nil, err
66-
}
67-
mc.parseTime = mc.cfg.ParseTime
68-
69-
// Connect to Server
70-
dialsLock.RLock()
71-
dial, ok := dials[mc.cfg.Net]
72-
dialsLock.RUnlock()
73-
if ok {
74-
mc.netConn, err = dial(mc.cfg.Addr)
75-
} else {
76-
nd := net.Dialer{Timeout: mc.cfg.Timeout}
77-
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
78-
}
79-
if err != nil {
80-
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
81-
errLog.Print("net.Error from Dial()': ", nerr.Error())
82-
return nil, driver.ErrBadConn
83-
}
84-
return nil, err
85-
}
86-
87-
// Enable TCP Keepalives on TCP connections
88-
if tc, ok := mc.netConn.(*net.TCPConn); ok {
89-
if err := tc.SetKeepAlive(true); err != nil {
90-
// Don't send COM_QUIT before handshake.
91-
mc.netConn.Close()
92-
mc.netConn = nil
93-
return nil, err
94-
}
95-
}
96-
97-
// Call startWatcher for context support (From Go 1.8)
98-
mc.startWatcher()
99-
100-
mc.buf = newBuffer(mc.netConn)
101-
102-
// Set I/O timeouts
103-
mc.buf.timeout = mc.cfg.ReadTimeout
104-
mc.writeTimeout = mc.cfg.WriteTimeout
105-
106-
// Reading Handshake Initialization Packet
107-
authData, plugin, err := mc.readHandshakePacket()
56+
cfg, err := ParseDSN(dsn)
10857
if err != nil {
109-
mc.cleanup()
11058
return nil, err
11159
}
112-
if plugin == "" {
113-
plugin = defaultAuthPlugin
60+
c := &connector{
61+
cfg: cfg,
11462
}
115-
116-
// Send Client Authentication Packet
117-
authResp, err := mc.auth(authData, plugin)
118-
if err != nil {
119-
// try the default auth plugin, if using the requested plugin failed
120-
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
121-
plugin = defaultAuthPlugin
122-
authResp, err = mc.auth(authData, plugin)
123-
if err != nil {
124-
mc.cleanup()
125-
return nil, err
126-
}
127-
}
128-
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
129-
mc.cleanup()
130-
return nil, err
131-
}
132-
133-
// Handle response to auth packet, switch methods if possible
134-
if err = mc.handleAuthResult(authData, plugin); err != nil {
135-
// Authentication failed and MySQL has already closed the connection
136-
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
137-
// Do not send COM_QUIT, just cleanup and return the error.
138-
mc.cleanup()
139-
return nil, err
140-
}
141-
142-
if mc.cfg.MaxAllowedPacket > 0 {
143-
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
144-
} else {
145-
// Get max allowed packet size
146-
maxap, err := mc.getSystemVar("max_allowed_packet")
147-
if err != nil {
148-
mc.Close()
149-
return nil, err
150-
}
151-
mc.maxAllowedPacket = stringToInt(maxap) - 1
152-
}
153-
if mc.maxAllowedPacket < maxPacketSize {
154-
mc.maxWriteSize = mc.maxAllowedPacket
155-
}
156-
157-
// Handle DSN Params
158-
err = mc.handleParams()
159-
if err != nil {
160-
mc.Close()
161-
return nil, err
162-
}
163-
164-
return mc, nil
63+
return c.Connect(context.Background())
16564
}
16665

16766
func init() {

driver_go1.10.go

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2+
//
3+
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
4+
//
5+
// This Source Code Form is subject to the terms of the Mozilla Public
6+
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7+
// You can obtain one at http://mozilla.org/MPL/2.0/.
8+
9+
// +build go1.10
10+
11+
package mysql
12+
13+
import (
14+
"crypto/rsa"
15+
"database/sql/driver"
16+
"math/big"
17+
)
18+
19+
// NewConnector returns new driver.Connector.
20+
func NewConnector(cfg *Config) driver.Connector {
21+
copyCfg := *cfg
22+
copyCfg.tls = cfg.tls.Clone()
23+
copyCfg.Params = make(map[string]string, len(cfg.Params))
24+
for k, v := range cfg.Params {
25+
copyCfg.Params[k] = v
26+
}
27+
if cfg.pubKey != nil {
28+
copyCfg.pubKey = &rsa.PublicKey{
29+
N: new(big.Int).Set(cfg.pubKey.N),
30+
E: cfg.pubKey.E,
31+
}
32+
}
33+
return &connector{cfg: &copyCfg}
34+
}
35+
36+
// OpenConnector implements driver.DriverContext.
37+
func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
38+
cfg, err := ParseDSN(dsn)
39+
if err != nil {
40+
return nil, err
41+
}
42+
return &connector{
43+
cfg: cfg,
44+
}, nil
45+
}

driver_go1.10_test.go

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2+
//
3+
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
4+
//
5+
// This Source Code Form is subject to the terms of the Mozilla Public
6+
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7+
// You can obtain one at http://mozilla.org/MPL/2.0/.
8+
9+
// +build go1.10
10+
11+
package mysql
12+
13+
import (
14+
"database/sql/driver"
15+
)
16+
17+
var _ driver.DriverContext = &MySQLDriver{}

0 commit comments

Comments
 (0)