-
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Make TryInsert functions within the packages module use INSERT ... ON CONFLICT #21063
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 25 commits
7f4e851
b29ab42
470064c
9c83ab8
bf94b55
8d3864a
71522ea
7843fe9
abcf334
430f964
5dff21a
f934a98
b3db6da
5bc4924
8f7987c
fc5d9aa
7ec881f
d8aa794
3f39045
15855df
38d540b
a941cba
5ef7902
04efbf9
2283b23
a282e66
62b1e20
f1222e8
25abc72
028b5a6
1c17006
ac6862a
dc4638c
ecd3eea
bf18cf8
1a6f5df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,12 +4,20 @@ | |
package db | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"reflect" | ||
"strings" | ||
"time" | ||
|
||
"code.gitea.io/gitea/modules/log" | ||
"code.gitea.io/gitea/modules/setting" | ||
"code.gitea.io/gitea/modules/util" | ||
|
||
"xorm.io/builder" | ||
"xorm.io/xorm/convert" | ||
"xorm.io/xorm/dialects" | ||
"xorm.io/xorm/schemas" | ||
) | ||
|
||
// BuildCaseInsensitiveLike returns a condition to check if the given value is like the given key case-insensitively. | ||
|
@@ -20,3 +28,360 @@ func BuildCaseInsensitiveLike(key, value string) builder.Cond { | |
} | ||
return builder.Like{"UPPER(" + key + ")", strings.ToUpper(value)} | ||
} | ||
|
||
// InsertOnConflictDoNothing will attempt to insert the provided bean but if there is a conflict it will not error out | ||
// This function will update the ID of the provided bean if there is an insertion | ||
// This does not do all of the conversions that xorm would do automatically but it does quite a number of them | ||
// once xorm has a working InsertOnConflictDoNothing this function could be removed. | ||
func InsertOnConflictDoNothing(ctx context.Context, bean interface{}) (bool, error) { | ||
e := GetEngine(ctx) | ||
|
||
tableName := x.TableName(bean, true) | ||
table, err := x.TableInfo(bean) | ||
if err != nil { | ||
return false, err | ||
} | ||
|
||
autoIncrCol := table.AutoIncrColumn() | ||
|
||
columns := table.Columns() | ||
|
||
colNames, values, zeroedColNames, zeroedValues, err := getColNamesAndValuesFromBean(bean, columns) | ||
if err != nil { | ||
return false, err | ||
} | ||
|
||
if len(colNames) == 0 { | ||
return false, fmt.Errorf("provided bean to insert has all empty values") | ||
} | ||
|
||
// MSSQL needs to separately pass in the columns with the unique constraint and we need to | ||
// include empty columns which are in the constraint in the insert for other dbs | ||
uniqueCols, uniqueValues, colNames, values := addInUniqueCols(colNames, values, zeroedColNames, zeroedValues, table) | ||
if len(uniqueCols) == 0 { | ||
return false, fmt.Errorf("provided bean has no unique constraints") | ||
} | ||
|
||
var insertArgs []any | ||
|
||
switch { | ||
case setting.Database.UseSQLite3 || setting.Database.UsePostgreSQL || setting.Database.UseMySQL: | ||
insertArgs = generateInsertNoConflictSQLAndArgs(tableName, colNames, values, autoIncrCol) | ||
case setting.Database.UseMSSQL: | ||
insertArgs = generateInsertNoConflictSQLAndArgsForMSSQL(tableName, colNames, values, uniqueCols, uniqueValues, autoIncrCol) | ||
default: | ||
return false, fmt.Errorf("database type not supported") | ||
} | ||
|
||
if autoIncrCol != nil && (setting.Database.UsePostgreSQL || setting.Database.UseMSSQL) { | ||
// Postgres and MSSQL do not use the LastInsertID mechanism | ||
// Therefore use query rather than exec and read the last provided ID back in | ||
|
||
res, err := e.Query(insertArgs...) | ||
if err != nil { | ||
return false, fmt.Errorf("error in query: %s, %w", insertArgs[0], err) | ||
} | ||
if len(res) == 0 { | ||
// this implies there was a conflict | ||
return false, nil | ||
} | ||
|
||
aiValue, err := table.AutoIncrColumn().ValueOf(bean) | ||
if err != nil { | ||
log.Error("unable to get value for autoincrcol of %#v %v", bean, err) | ||
} | ||
|
||
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { | ||
return true, nil | ||
} | ||
|
||
id := res[0][autoIncrCol.Name] | ||
err = convert.AssignValue(*aiValue, id) | ||
if err != nil { | ||
return true, fmt.Errorf("error in assignvalue %v %v %w", id, res, err) | ||
} | ||
return true, nil | ||
} | ||
|
||
res, err := e.Exec(values...) | ||
if err != nil { | ||
return false, err | ||
} | ||
|
||
n, err := res.RowsAffected() | ||
if err != nil { | ||
return n != 0, err | ||
} | ||
|
||
if n != 0 && autoIncrCol != nil { | ||
id, err := res.LastInsertId() | ||
if err != nil { | ||
return true, err | ||
} | ||
reflect.ValueOf(bean).Elem().FieldByName(autoIncrCol.FieldName).SetInt(id) | ||
} | ||
|
||
return n != 0, err | ||
} | ||
|
||
// generateInsertNoConflictSQLAndArgs will create the correct insert code for most of the DBs except MSSQL | ||
func generateInsertNoConflictSQLAndArgs(tableName string, colNames []string, args []any, autoIncrCol *schemas.Column) (insertArgs []any) { | ||
sb := &strings.Builder{} | ||
|
||
quote := x.Dialect().Quoter().Quote | ||
write := func(args ...string) { | ||
for _, arg := range args { | ||
_, _ = sb.WriteString(arg) | ||
} | ||
} | ||
write("INSERT ") | ||
if setting.Database.UseMySQL && autoIncrCol == nil { | ||
write("IGNORE ") | ||
} | ||
write("INTO ", quote(tableName), " (") | ||
_ = x.Dialect().Quoter().JoinWrite(sb, colNames, ",") | ||
write(") VALUES (?") | ||
for range colNames[1:] { | ||
write(",?") | ||
} | ||
switch { | ||
case setting.Database.UsePostgreSQL: | ||
write(") ON CONFLICT DO NOTHING") | ||
if autoIncrCol != nil { | ||
write(" RETURNING ", quote(autoIncrCol.Name)) | ||
} | ||
case setting.Database.UseSQLite3: | ||
write(") ON CONFLICT DO NOTHING") | ||
case setting.Database.UseMySQL: | ||
if autoIncrCol != nil { | ||
write(") ON DUPLICATE KEY UPDATE ", quote(autoIncrCol.Name), " = ", quote(autoIncrCol.Name)) | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry to bother, but why If you meant to do it for Otherwise I really can not understand its purpose. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why don't you try it. When I wrote this code it required the ON DUPLICATE KEY UPDATE and not the IGNORE There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (merged into next comment below) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it really works from XORM or Golang MySQL Driver, there should be some complete test cases for it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What you really need is this (update: out-dated)
Table: CREATE TABLE `t` (
`id` int NOT NULL AUTO_INCREMENT,
`k` varchar(100),
`v` varchar(100) DEFAULT '',
`i` int DEFAULT NULL,
PRIMARY KEY (`id`),
UNIQUE KEY `k` (`k`)
) ENGINE=InnoDB AUTO_INCREMENT=9
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. a) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. a) last_insert_id is the last successfully inserted ID in current session, it's always safe across transactions. And it won't be reset if a new insertion fails. you might have seen some incorrect last_insert_id/unrelated during your test. Update: out-dated. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. MySQL demo (update: out-dated, the demo is for filling the ID)
|
||
} | ||
args[0] = sb.String() | ||
return args | ||
} | ||
|
||
// generateInsertNoConflictSQLAndArgsForMSSQL writes the INSERT ... ON CONFLICT sql variant for MSSQL | ||
// MSSQL uses MERGE <tablename> WITH <lock> ... but needs to pre-select the unique cols first | ||
// then WHEN NOT MATCHED INSERT - this is kind of the opposite way round from INSERT ... ON CONFLICT | ||
func generateInsertNoConflictSQLAndArgsForMSSQL(tableName string, colNames []string, args []any, uniqueCols []string, uniqueArgs []any, autoIncrCol *schemas.Column) (insertArgs []any) { | ||
sb := &strings.Builder{} | ||
|
||
quote := x.Dialect().Quoter().Quote | ||
write := func(args ...string) { | ||
for _, arg := range args { | ||
_, _ = sb.WriteString(arg) | ||
} | ||
} | ||
|
||
write("MERGE ", quote(tableName), " WITH (HOLDLOCK) AS target USING (SELECT ? AS ") | ||
_ = x.Dialect().Quoter().JoinWrite(sb, uniqueCols, ", ? AS ") | ||
write(") AS src ON src.", quote(uniqueCols[0]), "= target.", quote(uniqueCols[0])) | ||
for _, uniqueCol := range uniqueCols[1:] { | ||
write(" AND src.", quote(uniqueCol), "= target.", quote(uniqueCol)) | ||
} | ||
write(" WHEN NOT MATCHED THEN INSERT (") | ||
_ = x.Dialect().Quoter().JoinWrite(sb, colNames, ",") | ||
write(") VALUES (?") | ||
for range colNames[1:] { | ||
write(", ?") | ||
} | ||
write(")") | ||
if autoIncrCol != nil { | ||
write(" OUTPUT INSERTED.", quote(autoIncrCol.Name)) | ||
} | ||
write(";") | ||
uniqueArgs[0] = sb.String() | ||
return append(uniqueArgs, args[1:]...) | ||
} | ||
|
||
// addInUniqueCols determines the columns that refer to unique constraints and creates slices for these | ||
// as they're needed by MSSQL. In addition, any columns which are zero-valued but are part of a constraint | ||
// are added back in to the colNames and args | ||
func addInUniqueCols(colNames []string, args []any, zeroedColNames []string, emptyArgs []any, table *schemas.Table) (uniqueCols []string, uniqueArgs []any, insertCols []string, insertArgs []any) { | ||
uniqueCols = make([]string, 0, len(table.Columns())) | ||
uniqueArgs = make([]interface{}, 1, len(uniqueCols)+1) // leave uniqueArgs[0] empty to put the SQL in | ||
|
||
// Iterate across the indexes in the provided table | ||
for _, index := range table.Indexes { | ||
if index.Type != schemas.UniqueType { | ||
continue | ||
} | ||
|
||
// index is a Unique constraint | ||
indexCol: | ||
for _, iCol := range index.Cols { | ||
for _, uCol := range uniqueCols { | ||
if uCol == iCol { | ||
// column is already included in uniqueCols so we don't need to add it again | ||
continue indexCol | ||
} | ||
} | ||
|
||
// Now iterate across colNames and add to the uniqueCols | ||
for i, col := range colNames { | ||
if col == iCol { | ||
uniqueCols = append(uniqueCols, col) | ||
uniqueArgs = append(uniqueArgs, args[i+1]) | ||
continue indexCol | ||
} | ||
} | ||
|
||
// If we still haven't found the column we need to look in the emptyColumns and add | ||
// it back into colNames and args as well as uniqueCols/uniqueArgs | ||
for i, col := range zeroedColNames { | ||
if col == iCol { | ||
// Always include empty unique columns in the insert statement as otherwise the insert no conflict will pass | ||
colNames = append(colNames, col) | ||
args = append(args, emptyArgs[i]) | ||
uniqueCols = append(uniqueCols, col) | ||
uniqueArgs = append(uniqueArgs, emptyArgs[i]) | ||
continue indexCol | ||
} | ||
} | ||
} | ||
} | ||
return uniqueCols, uniqueArgs, colNames, args | ||
} | ||
|
||
// getColNamesAndValuesFromBean reads the provided bean, providing two pairs of linked slices: | ||
// | ||
// - colNames and values | ||
// - zeroedColNames and zeroedValues | ||
// | ||
// colNames contains the names of the columns that have non-zero values in the provided bean | ||
// values contains the values - with one exception - values is 1-based so that values[0] is deliberately left zero | ||
// | ||
// emptyyColNames and zeroedValues accounts for the other columns - with zeroedValues containing the zero values | ||
func getColNamesAndValuesFromBean(bean interface{}, cols []*schemas.Column) (colNames []string, values []any, zeroedColNames []string, zeroedValues []any, err error) { | ||
colNames = make([]string, len(cols)) | ||
values = make([]any, len(cols)+1) // Leave args[0] to put the SQL in | ||
maxNonEmpty := 0 | ||
minEmpty := len(cols) | ||
|
||
val := reflect.ValueOf(bean) | ||
elem := val.Elem() | ||
for _, col := range cols { | ||
if fieldIdx := col.FieldIndex; fieldIdx != nil { | ||
fieldVal := elem.FieldByIndex(fieldIdx) | ||
if col.IsCreated || col.IsUpdated { | ||
result, err := setCurrentTime(fieldVal, col) | ||
if err != nil { | ||
return nil, nil, nil, nil, err | ||
} | ||
|
||
colNames[maxNonEmpty] = col.Name | ||
maxNonEmpty++ | ||
values[maxNonEmpty] = result | ||
continue | ||
} | ||
|
||
val, err := getValueFromField(fieldVal, col) | ||
if err != nil { | ||
return nil, nil, nil, nil, err | ||
} | ||
if fieldVal.IsZero() { | ||
values[minEmpty] = val // remember args is 1-based not 0-based | ||
minEmpty-- | ||
colNames[minEmpty] = col.Name | ||
continue | ||
} | ||
colNames[maxNonEmpty] = col.Name | ||
maxNonEmpty++ | ||
values[maxNonEmpty] = val | ||
} | ||
} | ||
|
||
return colNames[:maxNonEmpty], values[:maxNonEmpty+1], colNames[maxNonEmpty:], values[maxNonEmpty+1:], nil | ||
} | ||
|
||
func setCurrentTime(fieldVal reflect.Value, col *schemas.Column) (interface{}, error) { | ||
t := time.Now() | ||
result, err := dialects.FormatColumnTime(x.Dialect(), x.DatabaseTZ, col, t) | ||
if err != nil { | ||
return result, err | ||
} | ||
|
||
switch fieldVal.Type().Kind() { | ||
case reflect.Struct: | ||
fieldVal.Set(reflect.ValueOf(t).Convert(fieldVal.Type())) | ||
case reflect.Int, reflect.Int64, reflect.Int32: | ||
fieldVal.SetInt(t.Unix()) | ||
case reflect.Uint, reflect.Uint64, reflect.Uint32: | ||
fieldVal.SetUint(uint64(t.Unix())) | ||
} | ||
return result, nil | ||
} | ||
|
||
// getValueFromField extracts the reflected value from the provided fieldVal | ||
// this keeps the type and makes such that zero values work in the SQL Insert above | ||
func getValueFromField(fieldVal reflect.Value, col *schemas.Column) (any, error) { | ||
// Handle pointers to convert.Conversion | ||
if fieldVal.CanAddr() { | ||
if fieldConvert, ok := fieldVal.Addr().Interface().(convert.Conversion); ok { | ||
data, err := fieldConvert.ToDB() | ||
if err != nil { | ||
return nil, err | ||
} | ||
if data == nil { | ||
if col.Nullable { | ||
return nil, nil | ||
} | ||
data = []byte{} | ||
} | ||
if col.SQLType.IsBlob() { | ||
return data, nil | ||
} | ||
return string(data), nil | ||
} | ||
} | ||
|
||
// Handle nil pointer to convert.Conversion | ||
isNil := fieldVal.Kind() == reflect.Ptr && fieldVal.IsNil() | ||
if !isNil { | ||
if fieldConvert, ok := fieldVal.Interface().(convert.Conversion); ok { | ||
data, err := fieldConvert.ToDB() | ||
if err != nil { | ||
return nil, err | ||
} | ||
if data == nil { | ||
if col.Nullable { | ||
return nil, nil | ||
} | ||
data = []byte{} | ||
} | ||
if col.SQLType.IsBlob() { | ||
return data, nil | ||
} | ||
return string(data), nil | ||
} | ||
} | ||
|
||
// Handle common primitive types | ||
switch fieldVal.Type().Kind() { | ||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||
return fieldVal.Int(), nil | ||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | ||
return fieldVal.Uint(), nil | ||
case reflect.Float32, reflect.Float64: | ||
return fieldVal.Float(), nil | ||
case reflect.Complex64, reflect.Complex128: | ||
return fieldVal.Complex(), nil | ||
case reflect.String: | ||
return fieldVal.String(), nil | ||
case reflect.Bool: | ||
valBool := fieldVal.Bool() | ||
|
||
if setting.Database.UseMSSQL { | ||
if valBool { | ||
return 1, nil | ||
} | ||
return 0, nil | ||
} | ||
return valBool, nil | ||
default: | ||
} | ||
|
||
// just return the interface | ||
return fieldVal.Interface(), nil | ||
} |
Uh oh!
There was an error while loading. Please reload this page.