Skip to content

Implement NamedValueChecker for mysqlConn #690

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 16, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Jian Zhen <zhenjl at gmail.com>
Joshua Prunier <joshua.prunier at gmail.com>
Julien Lefevre <julien.lefevr at gmail.com>
Julien Schmidt <go-sql-driver at julienschmidt.com>
Justin Li <jli at j-li.net>
Justin Nuß <nuss.justin at gmail.com>
Kamil Dziedzic <kamil at klecza.pl>
Kevin Malachowski <kevin at chowski.com>
Expand Down
5 changes: 5 additions & 0 deletions connection_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,8 @@ func (mc *mysqlConn) startWatcher() {
}
}()
}

func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = converter{}.ConvertValue(nv.Value)
return
}
30 changes: 30 additions & 0 deletions connection_go18_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

// +build go1.8

package mysql

import (
"database/sql/driver"
"testing"
)

func TestCheckNamedValue(t *testing.T) {
value := driver.NamedValue{Value: ^uint64(0)}
x := &mysqlConn{}
err := x.CheckNamedValue(&value)

if err != nil {
t.Fatal("uint64 high-bit not convertible", err)
}

if value.Value != "18446744073709551615" {
t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value)
}
}
8 changes: 8 additions & 0 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
return int64(u64), nil
case reflect.Float32, reflect.Float64:
return rv.Float(), nil
case reflect.Bool:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes in this file affect all versions (not just 1.8). Can you explain what this does? In particular what does this do to 1.7 and lower?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nvm I just saw that this is from golang/go@d7c0de9

return rv.Bool(), nil
case reflect.Slice:
ek := rv.Type().Elem().Kind()
if ek == reflect.Uint8 {
return rv.Bytes(), nil
}
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
case reflect.String:
return rv.String(), nil
}
Expand Down
119 changes: 112 additions & 7 deletions statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,119 @@

package mysql

import "testing"
import (
"bytes"
"testing"
)

type customString string
func TestConvertDerivedString(t *testing.T) {
type derived string

func TestConvertValueCustomTypes(t *testing.T) {
var cstr customString = "string"
c := converter{}
if _, err := c.ConvertValue(cstr); err != nil {
t.Errorf("custom string type should be valid")
output, err := converter{}.ConvertValue(derived("value"))
if err != nil {
t.Fatal("Derived string type not convertible", err)
}

if output != "value" {
t.Fatalf("Derived string type not converted, got %#v %T", output, output)
}
}

func TestConvertDerivedByteSlice(t *testing.T) {
type derived []uint8

output, err := converter{}.ConvertValue(derived("value"))
if err != nil {
t.Fatal("Byte slice not convertible", err)
}

if bytes.Compare(output.([]byte), []byte("value")) != 0 {
t.Fatalf("Byte slice not converted, got %#v %T", output, output)
}
}

func TestConvertDerivedUnsupportedSlice(t *testing.T) {
type derived []int

_, err := converter{}.ConvertValue(derived{1})
if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" {
t.Fatal("Unexpected error", err)
}
}

func TestConvertDerivedBool(t *testing.T) {
type derived bool

output, err := converter{}.ConvertValue(derived(true))
if err != nil {
t.Fatal("Derived bool type not convertible", err)
}

if output != true {
t.Fatalf("Derived bool type not converted, got %#v %T", output, output)
}
}

func TestConvertPointer(t *testing.T) {
str := "value"

output, err := converter{}.ConvertValue(&str)
if err != nil {
t.Fatal("Pointer type not convertible", err)
}

if output != "value" {
t.Fatalf("Pointer type not converted, got %#v %T", output, output)
}
}

func TestConvertSignedIntegers(t *testing.T) {
values := []interface{}{
int8(-42),
int16(-42),
int32(-42),
int64(-42),
int(-42),
}

for _, value := range values {
output, err := converter{}.ConvertValue(value)
if err != nil {
t.Fatalf("%T type not convertible %s", value, err)
}

if output != int64(-42) {
t.Fatalf("%T type not converted, got %#v %T", value, output, output)
}
}
}

func TestConvertUnsignedIntegers(t *testing.T) {
values := []interface{}{
uint8(42),
uint16(42),
uint32(42),
uint64(42),
uint(42),
}

for _, value := range values {
output, err := converter{}.ConvertValue(value)
if err != nil {
t.Fatalf("%T type not convertible %s", value, err)
}

if output != int64(42) {
t.Fatalf("%T type not converted, got %#v %T", value, output, output)
}
}

output, err := converter{}.ConvertValue(^uint64(0))
if err != nil {
t.Fatal("uint64 high-bit not convertible", err)
}

if output != "18446744073709551615" {
t.Fatalf("uint64 high-bit not converted, got %#v %T", output, output)
}
}