Skip to content

Commit 9340cdc

Browse files
committed
mssql: support the connection session resetter interface
The bulk table load uses a per-session temp table. Pull out a single connection to fix the test. Fixes #329
1 parent 6a30f4e commit 9340cdc

File tree

11 files changed

+100
-46
lines changed

11 files changed

+100
-46
lines changed

buf.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,23 @@ func (w *tdsBuffer) WriteByte(b byte) error {
115115
return nil
116116
}
117117

118-
func (w *tdsBuffer) BeginPacket(packetType packetType) {
119-
w.wbuf[1] = 0 // Packet is incomplete. This byte is set again in FinishPacket.
118+
func (w *tdsBuffer) BeginPacket(packetType packetType, resetSession bool) {
119+
status := byte(0)
120+
if resetSession {
121+
switch packetType {
122+
// Reset session can only be set on the following packet types.
123+
case packSQLBatch, packRPCRequest, packTransMgrReq:
124+
status = 0x8
125+
}
126+
}
127+
w.wbuf[1] = status // Packet is incomplete. This byte is set again in FinishPacket.
120128
w.wpos = 8
121129
w.wPacketSeq = 1
122130
w.wPacketType = packetType
123131
}
124132

125133
func (w *tdsBuffer) FinishPacket() error {
126-
w.wbuf[1] = 1 // Mark this as the last packet in the message.
134+
w.wbuf[1] |= 1 // Mark this as the last packet in the message.
127135
return w.flush()
128136
}
129137

buf_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ func TestReadFailsOnSecondPacket(t *testing.T) {
151151
func TestWrite(t *testing.T) {
152152
memBuf := bytes.NewBuffer([]byte{})
153153
buf := newTdsBuffer(11, closableBuffer{memBuf})
154-
buf.BeginPacket(1)
154+
buf.BeginPacket(1, false)
155155
err := buf.WriteByte(2)
156156
if err != nil {
157157
t.Fatal("WriteByte failed:", err.Error())
@@ -172,7 +172,7 @@ func TestWrite(t *testing.T) {
172172
t.Fatalf("Written buffer has invalid content: %v", memBuf.Bytes())
173173
}
174174

175-
buf.BeginPacket(2)
175+
buf.BeginPacket(2, false)
176176
wrote, err = buf.Write([]byte{3, 4, 5, 6})
177177
if err != nil {
178178
t.Fatal("Write failed:", err.Error())

bulkcopy.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ func (b *Bulk) sendBulkCommand() (err error) {
128128
b.headerSent = true
129129

130130
var buf = b.cn.sess.buf
131-
buf.BeginPacket(packBulkLoadBCP)
131+
buf.BeginPacket(packBulkLoadBCP, false)
132132

133133
// send the columns metadata
134134
columnMetadata := b.createColMetadata()

bulkcopy_test.go

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
// +build go1.9
2+
13
package mssql
24

35
import (
6+
"context"
47
"database/sql"
58
"encoding/hex"
69
"log"
@@ -20,6 +23,7 @@ func TestBulkcopy(t *testing.T) {
2023
colname string
2124
val interface{}
2225
}
26+
2327
tableName := "#table_test"
2428
geom, _ := hex.DecodeString("E6100000010C00000000000034400000000000004440")
2529
testValues := []testValue{
@@ -71,18 +75,30 @@ func TestBulkcopy(t *testing.T) {
7175
values[i] = val.val
7276
}
7377

74-
conn := open(t)
78+
pool := open(t)
79+
defer pool.Close()
80+
81+
ctx, cancel := context.WithCancel(context.Background())
82+
defer cancel()
83+
84+
// Now that session resetting is supported, the use of the per session
85+
// temp table requires the use of a dedicated connection from the connection
86+
// pool.
87+
conn, err := pool.Conn(ctx)
88+
if err != nil {
89+
t.Error("failed to pull connection from pool", err)
90+
}
7591
defer conn.Close()
7692

77-
err := setupTable(conn, tableName)
93+
err = setupTable(ctx, conn, tableName)
7894
if err != nil {
79-
t.Error("Setup table failed: ", err.Error())
95+
t.Error("Setup table failed: ", err)
8096
return
8197
}
8298

8399
log.Println("Preparing copyin statement")
84100

85-
stmt, err := conn.Prepare(CopyIn(tableName, BulkOptions{}, columns...))
101+
stmt, err := conn.PrepareContext(ctx, CopyIn(tableName, BulkOptions{}, columns...))
86102

87103
for i := 0; i < 10; i++ {
88104
log.Printf("Executing copy in statement %d time with %d values", i+1, len(values))
@@ -105,14 +121,14 @@ func TestBulkcopy(t *testing.T) {
105121

106122
//check that all rows are present
107123
var rowCount int
108-
err = conn.QueryRow("select count(*) c from " + tableName).Scan(&rowCount)
124+
err = conn.QueryRowContext(ctx, "select count(*) c from "+tableName).Scan(&rowCount)
109125

110126
if rowCount != 10 {
111127
t.Errorf("unexpected row count %d", rowCount)
112128
}
113129

114130
//data verification
115-
rows, err := conn.Query("select " + strings.Join(columns, ",") + " from " + tableName)
131+
rows, err := conn.QueryContext(ctx, "select "+strings.Join(columns, ",")+" from "+tableName)
116132
if err != nil {
117133
log.Fatal(err)
118134
}
@@ -158,7 +174,7 @@ func compareValue(a interface{}, expected interface{}) bool {
158174
}
159175
}
160176

161-
func setupTable(conn *sql.DB, tableName string) (err error) {
177+
func setupTable(ctx context.Context, conn *sql.Conn, tableName string) (err error) {
162178
tablesql := `CREATE TABLE ` + tableName + ` (
163179
[id] [int] IDENTITY(1,1) NOT NULL,
164180
[test_nvarchar] [nvarchar](50) NULL,
@@ -203,7 +219,7 @@ func setupTable(conn *sql.DB, tableName string) (err error) {
203219
[id] ASC
204220
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
205221
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY];`
206-
_, err = conn.Exec(tablesql)
222+
_, err = conn.ExecContext(ctx, tablesql)
207223
if err != nil {
208224
log.Fatal("tablesql failed:", err)
209225
}

mssql.go

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,23 @@ func (d *Driver) SetLogger(logger Logger) {
9595
type Conn struct {
9696
sess *tdsSession
9797
transactionCtx context.Context
98+
resetSession bool
9899

99100
processQueryText bool
100101
connectionGood bool
101102

102103
outs map[string]interface{}
103104
}
104105

106+
func (c *Conn) ResetSession(ctx context.Context) error {
107+
if !c.connectionGood {
108+
return driver.ErrBadConn
109+
}
110+
c.resetSession = true
111+
112+
return nil
113+
}
114+
105115
func (c *Conn) checkBadConn(err error) error {
106116
// this is a hack to address Issue #275
107117
// we set connectionGood flag to false if
@@ -117,6 +127,7 @@ func (c *Conn) checkBadConn(err error) error {
117127
case nil:
118128
return nil
119129
case io.EOF:
130+
c.connectionGood = false
120131
return driver.ErrBadConn
121132
case driver.ErrBadConn:
122133
// It is an internal programming error if driver.ErrBadConn
@@ -174,7 +185,9 @@ func (c *Conn) sendCommitRequest() error {
174185
{hdrtype: dataStmHdrTransDescr,
175186
data: transDescrHdr{c.sess.tranid, 1}.pack()},
176187
}
177-
if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
188+
reset := c.resetSession
189+
c.resetSession = false
190+
if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
178191
if c.sess.logFlags&logErrors != 0 {
179192
c.sess.log.Printf("Failed to send CommitXact with %v", err)
180193
}
@@ -199,7 +212,9 @@ func (c *Conn) sendRollbackRequest() error {
199212
{hdrtype: dataStmHdrTransDescr,
200213
data: transDescrHdr{c.sess.tranid, 1}.pack()},
201214
}
202-
if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
215+
reset := c.resetSession
216+
c.resetSession = false
217+
if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
203218
if c.sess.logFlags&logErrors != 0 {
204219
c.sess.log.Printf("Failed to send RollbackXact with %v", err)
205220
}
@@ -234,7 +249,9 @@ func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) erro
234249
{hdrtype: dataStmHdrTransDescr,
235250
data: transDescrHdr{0, 1}.pack()},
236251
}
237-
if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil {
252+
reset := c.resetSession
253+
c.resetSession = false
254+
if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, "", reset); err != nil {
238255
if c.sess.logFlags&logErrors != 0 {
239256
c.sess.log.Printf("Failed to send BeginXact with %v", err)
240257
}
@@ -362,26 +379,30 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) {
362379
})
363380
}
364381

382+
conn := s.c
383+
365384
// no need to check number of parameters here, it is checked by database/sql
366-
if s.c.sess.logFlags&logSQL != 0 {
367-
s.c.sess.log.Println(s.query)
385+
if conn.sess.logFlags&logSQL != 0 {
386+
conn.sess.log.Println(s.query)
368387
}
369-
if s.c.sess.logFlags&logParams != 0 && len(args) > 0 {
388+
if conn.sess.logFlags&logParams != 0 && len(args) > 0 {
370389
for i := 0; i < len(args); i++ {
371390
if len(args[i].Name) > 0 {
372391
s.c.sess.log.Printf("\t@%s\t%v\n", args[i].Name, args[i].Value)
373392
} else {
374393
s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i].Value)
375394
}
376395
}
377-
378396
}
397+
398+
reset := conn.resetSession
399+
conn.resetSession = false
379400
if len(args) == 0 {
380-
if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil {
381-
if s.c.sess.logFlags&logErrors != 0 {
382-
s.c.sess.log.Printf("Failed to send SqlBatch with %v", err)
401+
if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil {
402+
if conn.sess.logFlags&logErrors != 0 {
403+
conn.sess.log.Printf("Failed to send SqlBatch with %v", err)
383404
}
384-
s.c.connectionGood = false
405+
conn.connectionGood = false
385406
return fmt.Errorf("failed to send SQL Batch: %v", err)
386407
}
387408
} else {
@@ -399,11 +420,11 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) {
399420
params[0] = makeStrParam(s.query)
400421
params[1] = makeStrParam(strings.Join(decls, ","))
401422
}
402-
if err = sendRpc(s.c.sess.buf, headers, proc, 0, params); err != nil {
403-
if s.c.sess.logFlags&logErrors != 0 {
404-
s.c.sess.log.Printf("Failed to send Rpc with %v", err)
423+
if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil {
424+
if conn.sess.logFlags&logErrors != 0 {
425+
conn.sess.log.Printf("Failed to send Rpc with %v", err)
405426
}
406-
s.c.connectionGood = false
427+
conn.connectionGood = false
407428
return fmt.Errorf("Failed to send RPC: %v", err)
408429
}
409430
}

mssql_go110.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// +build go1.10
2+
3+
package mssql
4+
5+
import (
6+
"database/sql/driver"
7+
)
8+
9+
var _ driver.Connector = &Connector{}
10+
var _ driver.SessionResetter = &Conn{}

net.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func (c *timeoutConn) Read(b []byte) (n int, err error) {
5858
func (c *timeoutConn) Write(b []byte) (n int, err error) {
5959
if c.buf != nil {
6060
if !c.packetPending {
61-
c.buf.BeginPacket(packPrelogin)
61+
c.buf.BeginPacket(packPrelogin, false)
6262
c.packetPending = true
6363
}
6464
n, err = c.buf.Write(b)

rpc.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ var (
5757
)
5858

5959
// http://msdn.microsoft.com/en-us/library/dd357576.aspx
60-
func sendRpc(buf *tdsBuffer, headers []headerStruct, proc ProcId, flags uint16, params []Param) (err error) {
61-
buf.BeginPacket(packRPCRequest)
60+
func sendRpc(buf *tdsBuffer, headers []headerStruct, proc ProcId, flags uint16, params []Param, resetSession bool) (err error) {
61+
buf.BeginPacket(packRPCRequest, resetSession)
6262
writeAllHeaders(buf, headers)
6363
if len(proc.name) == 0 {
6464
var idswitch uint16 = 0xffff

tds.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ func (p KeySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
162162
func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
163163
var err error
164164

165-
w.BeginPacket(packPrelogin)
165+
w.BeginPacket(packPrelogin, false)
166166
offset := uint16(5*len(fields) + 1)
167167
keys := make(KeySlice, 0, len(fields))
168168
for k, _ := range fields {
@@ -352,7 +352,7 @@ func manglePassword(password string) []byte {
352352

353353
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
354354
func sendLogin(w *tdsBuffer, login login) error {
355-
w.BeginPacket(packLogin7)
355+
w.BeginPacket(packLogin7, false)
356356
hostname := str2ucs2(login.HostName)
357357
username := str2ucs2(login.UserName)
358358
password := manglePassword(login.Password)
@@ -633,8 +633,8 @@ func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
633633
return nil
634634
}
635635

636-
func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err error) {
637-
buf.BeginPacket(packSQLBatch)
636+
func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) {
637+
buf.BeginPacket(packSQLBatch, resetSession)
638638

639639
if err = writeAllHeaders(buf, headers); err != nil {
640640
return
@@ -650,7 +650,7 @@ func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct) (err
650650
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
651651
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
652652
func sendAttention(buf *tdsBuffer) error {
653-
buf.BeginPacket(packAttention)
653+
buf.BeginPacket(packAttention, false)
654654
return buf.FinishPacket()
655655
}
656656

@@ -1337,7 +1337,7 @@ continue_login:
13371337
}
13381338
}
13391339
if sspi_msg != nil {
1340-
outbuf.BeginPacket(packSSPIMessage)
1340+
outbuf.BeginPacket(packSSPIMessage, false)
13411341
_, err = outbuf.Write(sspi_msg)
13421342
if err != nil {
13431343
return nil, err

tds_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func TestSendSqlBatch(t *testing.T) {
8989
{hdrtype: dataStmHdrTransDescr,
9090
data: transDescrHdr{0, 1}.pack()},
9191
}
92-
err = sendSqlBatch72(conn.buf, "select 1", headers)
92+
err = sendSqlBatch72(conn.buf, "select 1", headers, true)
9393
if err != nil {
9494
t.Error("Sending sql batch failed", err.Error())
9595
return

tran.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@ const (
2828
isolationSnapshot = 5
2929
)
3030

31-
func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel,
32-
name string) (err error) {
33-
buf.BeginPacket(packTransMgrReq)
31+
func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel, name string, resetSession bool) (err error) {
32+
buf.BeginPacket(packTransMgrReq, resetSession)
3433
writeAllHeaders(buf, headers)
3534
var rqtype uint16 = tmBeginXact
3635
err = binary.Write(buf, binary.LittleEndian, &rqtype)
@@ -52,8 +51,8 @@ const (
5251
fBeginXact = 1
5352
)
5453

55-
func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string) error {
56-
buf.BeginPacket(packTransMgrReq)
54+
func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error {
55+
buf.BeginPacket(packTransMgrReq, resetSession)
5756
writeAllHeaders(buf, headers)
5857
var rqtype uint16 = tmCommitXact
5958
err := binary.Write(buf, binary.LittleEndian, &rqtype)
@@ -81,8 +80,8 @@ func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags u
8180
return buf.FinishPacket()
8281
}
8382

84-
func sendRollbackXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string) error {
85-
buf.BeginPacket(packTransMgrReq)
83+
func sendRollbackXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error {
84+
buf.BeginPacket(packTransMgrReq, resetSession)
8685
writeAllHeaders(buf, headers)
8786
var rqtype uint16 = tmRollbackXact
8887
err := binary.Write(buf, binary.LittleEndian, &rqtype)

0 commit comments

Comments
 (0)