Skip to content

Commit 8424410

Browse files
kyleconroyJille
andauthored
feat(mysql): :copyfrom support via LOAD DATA INFILE (#2545)
This enables the :copyfrom query annotation for people using go-sql-driver/mysql that transforms it into a LOAD DATA LOCAL INFILE. We don't have a way to get the timezone from the connection, so I've simply blocked people from using time.Times in their copyfrom. --------- Co-authored-by: Jille Timmermans <[email protected]>
1 parent 39f16cc commit 8424410

File tree

16 files changed

+338
-14
lines changed

16 files changed

+338
-14
lines changed

docs/howto/insert.md

+23-2
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ func (q *Queries) CreateAuthorAndReturnId(ctx context.Context, arg CreateAuthorA
122122

123123
## Using CopyFrom
124124

125-
PostgreSQL supports the Copy Protocol that can insert rows a lot faster than sequential inserts. You can use this easily with sqlc:
125+
PostgreSQL supports the [COPY protocol](https://www.postgresql.org/docs/current/sql-copy.html) that can insert rows a lot faster than sequential inserts. You can use this easily with sqlc:
126126

127127
```sql
128128
CREATE TABLE authors (
@@ -142,6 +142,27 @@ type CreateAuthorsParams struct {
142142
}
143143

144144
func (q *Queries) CreateAuthors(ctx context.Context, arg []CreateAuthorsParams) (int64, error) {
145-
return q.db.CopyFrom(ctx, []string{"authors"}, []string{"name", "bio"}, &iteratorForCreateAuthors{rows: arg})
145+
...
146146
}
147147
```
148+
149+
MySQL supports a similar feature using [LOAD DATA](https://dev.mysql.com/doc/refman/8.0/en/load-data.html).
150+
151+
Errors and duplicate keys are treated as warnings and insertion will
152+
continue, even without an error for some cases. Use this in a transaction
153+
and use SHOW WARNINGS to check for any problems and roll back if you want to.
154+
155+
Check the [error handling](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-error-handling) documentation for more information.
156+
157+
```sql
158+
CREATE TABLE foo (a text, b integer, c DATETIME, d DATE);
159+
160+
-- name: InsertValues :copyfrom
161+
INSERT INTO foo (a, b, c, d) VALUES (?, ?, ?, ?);
162+
```
163+
164+
```go
165+
func (q *Queries) InsertValues(ctx context.Context, arg []InsertValuesParams) (int64, error) {
166+
...
167+
}
168+
```

internal/codegen/golang/driver.go

+9-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package golang
22

3-
type SQLDriver int
3+
type SQLDriver string
44

55
const (
66
SQLPackagePGXV4 string = "pgx/v4"
@@ -9,13 +9,12 @@ const (
99
)
1010

1111
const (
12-
SQLDriverPGXV4 SQLDriver = iota
13-
SQLDriverPGXV5
14-
SQLDriverLibPQ
12+
SQLDriverPGXV4 SQLDriver = "github.com/jackc/pgx/v4"
13+
SQLDriverPGXV5 = "github.com/jackc/pgx/v5"
14+
SQLDriverLibPQ = "github.com/lib/pq"
15+
SQLDriverGoSQLDriverMySQL = "github.com/go-sql-driver/mysql"
1516
)
1617

17-
const SQLDriverGoSQLDriverMySQL = "github.com/go-sql-driver/mysql"
18-
1918
func parseDriver(sqlPackage string) SQLDriver {
2019
switch sqlPackage {
2120
case SQLPackagePGXV4:
@@ -31,6 +30,10 @@ func (d SQLDriver) IsPGX() bool {
3130
return d == SQLDriverPGXV4 || d == SQLDriverPGXV5
3231
}
3332

33+
func (d SQLDriver) IsGoSQLDriverMySQL() bool {
34+
return d == SQLDriverGoSQLDriverMySQL
35+
}
36+
3437
func (d SQLDriver) Package() string {
3538
switch d {
3639
case SQLDriverPGXV4:

internal/codegen/golang/gen.go

+20-2
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,15 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie
145145
SqlcVersion: req.SqlcVersion,
146146
}
147147

148-
if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() {
149-
return nil, errors.New(":copyfrom is only supported by pgx")
148+
if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && golang.SqlDriver != SQLDriverGoSQLDriverMySQL {
149+
return nil, errors.New(":copyfrom is only supported by pgx and github.com/go-sql-driver/mysql")
150+
}
151+
152+
if tctx.UsesCopyFrom && golang.SqlDriver == SQLDriverGoSQLDriverMySQL {
153+
if err := checkNoTimesForMySQLCopyFrom(queries); err != nil {
154+
return nil, err
155+
}
156+
tctx.SQLDriver = SQLDriverGoSQLDriverMySQL
150157
}
151158

152159
if tctx.UsesBatch && !tctx.SQLDriver.IsPGX() {
@@ -294,6 +301,17 @@ func usesBatch(queries []Query) bool {
294301
return false
295302
}
296303

304+
func checkNoTimesForMySQLCopyFrom(queries []Query) error {
305+
for _, q := range queries {
306+
for _, f := range q.Arg.Fields() {
307+
if f.Type == "time.Time" {
308+
return fmt.Errorf("values with a timezone are not yet supported")
309+
}
310+
}
311+
}
312+
return nil
313+
}
314+
297315
func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enum, []Struct) {
298316
keepTypes := make(map[string]struct{})
299317

internal/codegen/golang/imports.go

+7
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,13 @@ func (i *importer) copyfromImports() fileImports {
407407
})
408408

409409
std["context"] = struct{}{}
410+
if i.Settings.Go.SqlDriver == SQLDriverGoSQLDriverMySQL {
411+
std["io"] = struct{}{}
412+
std["fmt"] = struct{}{}
413+
std["sync/atomic"] = struct{}{}
414+
pkg[ImportSpec{Path: "github.com/go-sql-driver/mysql"}] = struct{}{}
415+
pkg[ImportSpec{Path: "github.com/hexon/mysqltsv"}] = struct{}{}
416+
}
410417

411418
return sortedImports(std, pkg)
412419
}

internal/codegen/golang/query.go

+36-2
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,18 @@ func (v QueryValue) Params() string {
129129
return "\n" + strings.Join(out, ",\n")
130130
}
131131

132-
func (v QueryValue) ColumnNames() string {
132+
func (v QueryValue) ColumnNames() []string {
133+
if v.Struct == nil {
134+
return []string{v.DBName}
135+
}
136+
names := make([]string, len(v.Struct.Fields))
137+
for i, f := range v.Struct.Fields {
138+
names[i] = f.DBName
139+
}
140+
return names
141+
}
142+
143+
func (v QueryValue) ColumnNamesAsGoSlice() string {
133144
if v.Struct == nil {
134145
return fmt.Sprintf("[]string{%q}", v.DBName)
135146
}
@@ -187,6 +198,19 @@ func (v QueryValue) Scan() string {
187198
return "\n" + strings.Join(out, ",\n")
188199
}
189200

201+
func (v QueryValue) Fields() []Field {
202+
if v.Struct != nil {
203+
return v.Struct.Fields
204+
}
205+
return []Field{
206+
{
207+
Name: v.Name,
208+
DBName: v.DBName,
209+
Type: v.Typ,
210+
},
211+
}
212+
}
213+
190214
func (v QueryValue) VariableForField(f Field) string {
191215
if !v.IsStruct() {
192216
return v.Name
@@ -218,7 +242,7 @@ func (q Query) hasRetType() bool {
218242
return scanned && !q.Ret.isEmpty()
219243
}
220244

221-
func (q Query) TableIdentifier() string {
245+
func (q Query) TableIdentifierAsGoSlice() string {
222246
escapedNames := make([]string, 0, 3)
223247
for _, p := range []string{q.Table.Catalog, q.Table.Schema, q.Table.Name} {
224248
if p != "" {
@@ -227,3 +251,13 @@ func (q Query) TableIdentifier() string {
227251
}
228252
return "[]string{" + strings.Join(escapedNames, ", ") + "}"
229253
}
254+
255+
func (q Query) TableIdentifierForMySQL() string {
256+
escapedNames := make([]string, 0, 3)
257+
for _, p := range []string{q.Table.Catalog, q.Table.Schema, q.Table.Name} {
258+
if p != "" {
259+
escapedNames = append(escapedNames, fmt.Sprintf("`%s`", p))
260+
}
261+
}
262+
return strings.Join(escapedNames, ".")
263+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
{{define "copyfromCodeGoSqlDriver"}}
2+
{{range .GoQueries}}
3+
{{if eq .Cmd ":copyfrom" }}
4+
var readerHandlerSequenceFor{{.MethodName}} uint32 = 1
5+
6+
func convertRowsFor{{.MethodName}}(w *io.PipeWriter, {{.Arg.SlicePair}}) {
7+
e := mysqltsv.NewEncoder(w, {{ len .Arg.Fields }}, nil)
8+
for _, row := range {{.Arg.Name}} {
9+
{{- with $arg := .Arg }}
10+
{{- range $arg.Fields}}
11+
{{- if eq .Type "string"}}
12+
e.AppendString({{if eq (len $arg.Fields) 1}}row{{else}}row.{{.Name}}{{end}})
13+
{{- else if eq .Type "[]byte"}}
14+
e.AppendBytes({{if eq (len $arg.Fields) 1}}row{{else}}row.{{.Name}}{{end}})
15+
{{- else}}
16+
e.AppendValue({{if eq (len $arg.Fields) 1}}row{{else}}row.{{.Name}}{{end}})
17+
{{- end}}
18+
{{- end}}
19+
{{- end}}
20+
}
21+
w.CloseWithError(e.Close())
22+
}
23+
24+
{{range .Comments}}//{{.}}
25+
{{end -}}
26+
// {{.MethodName}} uses MySQL's LOAD DATA LOCAL INFILE and is not atomic.
27+
//
28+
// Errors and duplicate keys are treated as warnings and insertion will
29+
// continue, even without an error for some cases. Use this in a transaction
30+
// and use SHOW WARNINGS to check for any problems and roll back if you want to.
31+
//
32+
// Check the documentation for more information:
33+
// https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-error-handling
34+
func (q *Queries) {{.MethodName}}(ctx context.Context{{if $.EmitMethodsWithDBArgument}}, db DBTX{{end}}, {{.Arg.SlicePair}}) (int64, error) {
35+
pr, pw := io.Pipe()
36+
defer pr.Close()
37+
rh := fmt.Sprintf("{{.MethodName}}_%d", atomic.AddUint32(&readerHandlerSequenceFor{{.MethodName}}, 1))
38+
mysql.RegisterReaderHandler(rh, func() io.Reader { return pr })
39+
defer mysql.DeregisterReaderHandler(rh)
40+
go convertRowsFor{{.MethodName}}(pw, {{.Arg.Name}})
41+
// The string interpolation is necessary because LOAD DATA INFILE requires
42+
// the file name to be given as a literal string.
43+
result, err := {{if (not $.EmitMethodsWithDBArgument)}}q.{{end}}db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE {{.TableIdentifierForMySQL}} %s ({{range $index, $name := .Arg.ColumnNames}}{{if gt $index 0}}, {{end}}{{$name}}{{end}})", "Reader::" + rh, mysqltsv.Escaping))
44+
if err != nil {
45+
return 0, err
46+
}
47+
return result.RowsAffected()
48+
}
49+
50+
{{end}}
51+
{{end}}
52+
{{end}}

internal/codegen/golang/templates/pgx/copyfromCopy.tmpl

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ func (r iteratorFor{{.MethodName}}) Err() error {
3939
{{end -}}
4040
{{- if $.EmitMethodsWithDBArgument -}}
4141
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) (int64, error) {
42-
return db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
42+
return db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
4343
{{- else -}}
4444
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) {
45-
return q.db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
45+
return q.db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
4646
{{- end}}
4747
}
4848

internal/codegen/golang/templates/template.tmpl

+2
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ import (
186186
{{define "copyfromCode"}}
187187
{{if .SQLDriver.IsPGX }}
188188
{{- template "copyfromCodePgx" .}}
189+
{{else if .SQLDriver.IsGoSQLDriverMySQL }}
190+
{{- template "copyfromCodeGoSqlDriver" .}}
189191
{{end}}
190192
{{end}}
191193

internal/endtoend/testdata/copyfrom/mysql/go/copyfrom.go

+88
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/copyfrom/mysql/go/db.go

+31
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)