Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Implement ifnull and nullif functions. #555

Merged
merged 3 commits into from
Nov 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,21 @@ go get gopkg.in/src-d/go-mysql-server.v0

We are continuously adding more functionality to go-mysql-server. We support a subset of what is supported in MySQL, to see what is currently included check the [SUPPORTED](./SUPPORTED.md) file.

# Third-party clients
## Third-party clients

We support and actively test against certain third-party clients to ensure compatibility between them and go-mysql-server. You can check out the list of supported third party clients in the [SUPPORTED_CLIENTS](./SUPPORTED_CLIENTS.md) file along with some examples on how to connect to go-mysql-server using them.

## Custom functions

- `COUNT(expr)`: Returns a count of the number of non-NULL values of expr in the rows retrieved by a SELECT statement.
- `MIN(expr)`: Returns the minimum value of expr.
- `MAX(expr)`: Returns the maximum value of expr.
- `AVG(expr)`: Returns the average value of expr.
- `SUM(expr)`: Returns the sum of expr.
- `IS_BINARY(blob)`: Returns whether a BLOB is a binary file or not.
- `SUBSTRING(str, pos)`, `SUBSTRING(str, pos, len)`: Return a substring from the provided string.
- `SUBSTRING(str, pos)`, `SUBSTRING(str, pos, len)` : Return a substring from the provided string.
- `SUBSTR(str, pos)`, `SUBSTR(str, pos, len)` : Return a substring from the provided string.
- `MID(str, pos)`, `MID(str, pos, len)` : Return a substring from the provided string.
- Date and Timestamp functions: `YEAR(date)`, `MONTH(date)`, `DAY(date)`, `WEEKDAY(date)`, `HOUR(date)`, `MINUTE(date)`, `SECOND(date)`, `DAYOFWEEK(date)`, `DAYOFYEAR(date)`.
- `ARRAY_LENGTH(json)`: If the json representation is an array, this function returns its size.
- `SPLIT(str,sep)`: Receives a string and a separator and returns the parts of the string split by the separator as a JSON array of strings.
Expand All @@ -70,6 +77,23 @@ We support and actively test against certain third-party clients to ensure compa
- `ROUND(number, decimals)`: Round the `number` to `decimals` decimal places.
- `CONNECTION_ID()`: Return the current connection ID.
- `SOUNDEX(str)`: Returns the soundex of a string.
- `JSON_EXTRACT(json_doc, path, ...)`: Extracts data from a json document using json paths.
- `LN(X)`: Return the natural logarithm of X.
- `LOG2(X)`: Returns the base-2 logarithm of X.
- `LOG10(X)`: Returns the base-10 logarithm of X.
- `LOG(X), LOG(B, X)`: If called with one parameter, this function returns the natural logarithm of X. If called with two parameters, this function returns the logarithm of X to the base B. If X is less than or equal to 0, or if B is less than or equal to 1, then NULL is returned.
- `RPAD(str, len, padstr)`: Returns the string str, right-padded with the string padstr to a length of len characters.
- `LPAD(str, len, padstr)`: Return the string argument, left-padded with the specified string.
- `SQRT(X)`: Returns the square root of a nonnegative number X.
- `POW(X, Y)`, `POWER(X, Y)`: Returns the value of X raised to the power of Y.
- `TRIM(str)`: Returns the string str with all spaces removed.
- `LTRIM(str)`: Returns the string str with leading space characters removed.
- `RTRIM(str)`: Returns the string str with trailing space characters removed.
- `REVERSE(str)`: Returns the string str with the order of the characters reversed.
- `REPEAT(str, count)`: Returns a string consisting of the string str repeated count times.
- `REPLACE(str,from_str,to_str)`: Returns the string str with all occurrences of the string from_str replaced by the string to_str.
- `IFNULL(expr1, expr2)`: If expr1 is not NULL, IFNULL() returns expr1; otherwise it returns expr2.
- `NULLIF(expr1, expr2)`: Returns NULL if expr1 = expr2 is true, otherwise returns expr1.

## Example

Expand Down
66 changes: 66 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,72 @@ var queries = []struct {
{"tabletest", "InnoDB", "10", "Fixed", int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), int64(0), nil, nil, nil, "utf8_bin", nil, nil},
},
},
{
`SELECT NULL`,
[]sql.Row{
{nil},
},
},
{
`SELECT nullif('abc', NULL)`,
[]sql.Row{
{"abc"},
},
},
{
`SELECT nullif(NULL, NULL)`,
[]sql.Row{
{sql.Null},
},
},
{
`SELECT nullif(NULL, 123)`,
[]sql.Row{
{nil},
},
},
{
`SELECT nullif(123, 123)`,
[]sql.Row{
{sql.Null},
},
},
{
`SELECT nullif(123, 321)`,
[]sql.Row{
{int64(123)},
},
},
{
`SELECT ifnull(123, NULL)`,
[]sql.Row{
{int64(123)},
},
},
{
`SELECT ifnull(NULL, NULL)`,
[]sql.Row{
{nil},
},
},
{
`SELECT ifnull(NULL, 123)`,
[]sql.Row{
{int64(123)},
},
},
{
`SELECT ifnull(123, 123)`,
[]sql.Row{
{int64(123)},
},
},
{
`SELECT ifnull(123, 321)`,
[]sql.Row{
{int64(123)},
},
},
}

func TestQueries(t *testing.T) {
Expand Down
81 changes: 81 additions & 0 deletions sql/expression/function/ifnull.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package function

import (
"fmt"

"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
)

// IfNull function returns the specified value IF the expression is NULL, otherwise return the expression.
type IfNull struct {
expression.BinaryExpression
}

// NewIfNull returns a new IFNULL UDF
func NewIfNull(ex, value sql.Expression) sql.Expression {
return &IfNull{
expression.BinaryExpression{
Left: ex,
Right: value,
},
}
}

// Eval implements the Expression interface.
func (f *IfNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
left, err := f.Left.Eval(ctx, row)
if err != nil {
return nil, err
}
if left != nil {
return left, nil
}

right, err := f.Right.Eval(ctx, row)
if err != nil {
return nil, err
}
return right, nil
}

// Type implements the Expression interface.
func (f *IfNull) Type() sql.Type {
if sql.IsNull(f.Left) {
if sql.IsNull(f.Right) {
return sql.Null
}
return f.Right.Type()
}
return f.Left.Type()
}

// IsNullable implements the Expression interface.
func (f *IfNull) IsNullable() bool {
if sql.IsNull(f.Left) {
if sql.IsNull(f.Right) {
return true
}
return f.Right.IsNullable()
}
return f.Left.IsNullable()
}

func (f *IfNull) String() string {
return fmt.Sprintf("ifnull(%s, %s)", f.Left, f.Right)
}

// TransformUp implements the Expression interface.
func (f *IfNull) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
left, err := f.Left.TransformUp(fn)
if err != nil {
return nil, err
}

right, err := f.Right.TransformUp(fn)
if err != nil {
return nil, err
}

return fn(NewIfNull(left, right))
}
36 changes: 36 additions & 0 deletions sql/expression/function/ifnull_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package function

import (
"testing"

"github.com/stretchr/testify/require"
"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
)

func TestIfNull(t *testing.T) {
testCases := []struct {
expression interface{}
value interface{}
expected interface{}
}{
{"foo", "bar", "foo"},
{"foo", "foo", "foo"},
{nil, "foo", "foo"},
{"foo", nil, "foo"},
{nil, nil, nil},
{"", nil, ""},
}

f := NewIfNull(
expression.NewGetField(0, sql.Text, "expression", true),
expression.NewGetField(1, sql.Text, "value", true),
)
require.Equal(t, sql.Text, f.Type())

for _, tc := range testCases {
v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(tc.expression, tc.value))
require.NoError(t, err)
require.Equal(t, tc.expected, v)
}
}
73 changes: 73 additions & 0 deletions sql/expression/function/nullif.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package function

import (
"fmt"

"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
)

// NullIf function compares two expressions and returns NULL if they are equal. Otherwise, the first expression is returned.
type NullIf struct {
expression.BinaryExpression
}

// NewNullIf returns a new NULLIF UDF
func NewNullIf(ex1, ex2 sql.Expression) sql.Expression {
return &NullIf{
expression.BinaryExpression{
Left: ex1,
Right: ex2,
},
}
}

// Eval implements the Expression interface.
func (f *NullIf) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
if sql.IsNull(f.Left) && sql.IsNull(f.Right) {
return sql.Null, nil
}

val, err := expression.NewEquals(f.Left, f.Right).Eval(ctx, row)
if err != nil {
return nil, err
}
if b, ok := val.(bool); ok && b {
return sql.Null, nil
}

return f.Left.Eval(ctx, row)
}

// Type implements the Expression interface.
func (f *NullIf) Type() sql.Type {
if sql.IsNull(f.Left) {
return sql.Null
}

return f.Left.Type()
}

// IsNullable implements the Expression interface.
func (f *NullIf) IsNullable() bool {
return true
}

func (f *NullIf) String() string {
return fmt.Sprintf("nullif(%s, %s)", f.Left, f.Right)
}

// TransformUp implements the Expression interface.
func (f *NullIf) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
left, err := f.Left.TransformUp(fn)
if err != nil {
return nil, err
}

right, err := f.Right.TransformUp(fn)
if err != nil {
return nil, err
}

return fn(NewNullIf(left, right))
}
36 changes: 36 additions & 0 deletions sql/expression/function/nullif_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package function

import (
"testing"

"github.com/stretchr/testify/require"
"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
)

func TestNullIf(t *testing.T) {
testCases := []struct {
ex1 interface{}
ex2 interface{}
expected interface{}
}{
{"foo", "bar", "foo"},
{"foo", "foo", sql.Null},
{nil, "foo", nil},
{"foo", nil, "foo"},
{nil, nil, nil},
{"", nil, ""},
}

f := NewNullIf(
expression.NewGetField(0, sql.Text, "ex1", true),
expression.NewGetField(1, sql.Text, "ex2", true),
)
require.Equal(t, sql.Text, f.Type())

for _, tc := range testCases {
v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(tc.ex1, tc.ex2))
require.NoError(t, err)
require.Equal(t, tc.expected, v)
}
}
6 changes: 4 additions & 2 deletions sql/expression/function/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ var Defaults = sql.Functions{
"split": sql.Function2(NewSplit),
"concat": sql.FunctionN(NewConcat),
"concat_ws": sql.FunctionN(NewConcatWithSeparator),
"coalesce": sql.FunctionN(NewCoalesce),
"lower": sql.Function1(NewLower),
"upper": sql.Function1(NewUpper),
"ceiling": sql.Function1(NewCeil),
"ceil": sql.Function1(NewCeil),
"floor": sql.Function1(NewFloor),
"round": sql.FunctionN(NewRound),
"coalesce": sql.FunctionN(NewCoalesce),
"json_extract": sql.FunctionN(NewJSONExtract),
"connection_id": sql.Function0(NewConnectionID),
"soundex": sql.Function1(NewSoundex),
"json_extract": sql.FunctionN(NewJSONExtract),
"ln": sql.Function1(NewLogBaseFunc(float64(math.E))),
"log2": sql.Function1(NewLogBaseFunc(float64(2))),
"log10": sql.Function1(NewLogBaseFunc(float64(10))),
Expand All @@ -66,4 +66,6 @@ var Defaults = sql.Functions{
"reverse": sql.Function1(NewReverse),
"repeat": sql.Function2(NewRepeat),
"replace": sql.Function3(NewReplace),
"ifnull": sql.Function2(NewIfNull),
"nullif": sql.Function2(NewNullIf),
}
Loading