Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit ab5656f

Browse files
authored
sql: implement EXPLODE and generators (#720)
sql: implement EXPLODE and generators
2 parents 4dcaf78 + 728f747 commit ab5656f

15 files changed

+1010
-66
lines changed

engine_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2459,6 +2459,66 @@ func TestDescribeNoPruneColumns(t *testing.T) {
24592459
require.Len(p.Schema(), 3)
24602460
}
24612461

2462+
var generatorQueries = []struct {
2463+
query string
2464+
expected []sql.Row
2465+
}{
2466+
{
2467+
`SELECT a, EXPLODE(b), c FROM t`,
2468+
[]sql.Row{
2469+
{int64(1), "a", "first"},
2470+
{int64(1), "b", "first"},
2471+
{int64(2), "c", "second"},
2472+
{int64(2), "d", "second"},
2473+
{int64(3), "e", "third"},
2474+
{int64(3), "f", "third"},
2475+
},
2476+
},
2477+
{
2478+
`SELECT a, EXPLODE(b) AS x, c FROM t`,
2479+
[]sql.Row{
2480+
{int64(1), "a", "first"},
2481+
{int64(1), "b", "first"},
2482+
{int64(2), "c", "second"},
2483+
{int64(2), "d", "second"},
2484+
{int64(3), "e", "third"},
2485+
{int64(3), "f", "third"},
2486+
},
2487+
},
2488+
{
2489+
`SELECT a, EXPLODE(b) AS x, c FROM t WHERE x = 'e'`,
2490+
[]sql.Row{
2491+
{int64(3), "e", "third"},
2492+
},
2493+
},
2494+
}
2495+
2496+
func TestGenerators(t *testing.T) {
2497+
table := mem.NewPartitionedTable("t", sql.Schema{
2498+
{Name: "a", Type: sql.Int64, Source: "t"},
2499+
{Name: "b", Type: sql.Array(sql.Text), Source: "t"},
2500+
{Name: "c", Type: sql.Text, Source: "t"},
2501+
}, testNumPartitions)
2502+
2503+
insertRows(
2504+
t, table,
2505+
sql.NewRow(int64(1), []interface{}{"a", "b"}, "first"),
2506+
sql.NewRow(int64(2), []interface{}{"c", "d"}, "second"),
2507+
sql.NewRow(int64(3), []interface{}{"e", "f"}, "third"),
2508+
)
2509+
2510+
db := mem.NewDatabase("db")
2511+
db.AddTable("t", table)
2512+
2513+
catalog := sql.NewCatalog()
2514+
catalog.AddDatabase(db)
2515+
e := sqle.New(catalog, analyzer.NewDefault(catalog), new(sqle.Config))
2516+
2517+
for _, q := range generatorQueries {
2518+
testQuery(t, e, q.query, q.expected)
2519+
}
2520+
}
2521+
24622522
func insertRows(t *testing.T, table sql.Inserter, rows ...sql.Row) {
24632523
t.Helper()
24642524

server/handler.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,12 @@ func (h *Handler) ComQuery(
125125
return err
126126
}
127127

128-
r.Rows = append(r.Rows, rowToSQL(schema, row))
128+
outputRow, err := rowToSQL(schema, row)
129+
if err != nil {
130+
return err
131+
}
132+
133+
r.Rows = append(r.Rows, outputRow)
129134
r.RowsAffected++
130135
}
131136

@@ -203,13 +208,17 @@ func (h *Handler) handleKill(conn *mysql.Conn, query string) (bool, error) {
203208
return true, nil
204209
}
205210

206-
func rowToSQL(s sql.Schema, row sql.Row) []sqltypes.Value {
211+
func rowToSQL(s sql.Schema, row sql.Row) ([]sqltypes.Value, error) {
207212
o := make([]sqltypes.Value, len(row))
213+
var err error
208214
for i, v := range row {
209-
o[i] = s[i].Type.SQL(v)
215+
o[i], err = s[i].Type.SQL(v)
216+
if err != nil {
217+
return nil, err
218+
}
210219
}
211220

212-
return o
221+
return o, nil
213222
}
214223

215224
func schemaToFields(s sql.Schema) []*query.Field {

sql/analyzer/resolve_generators.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package analyzer
2+
3+
import (
4+
"gopkg.in/src-d/go-errors.v1"
5+
"github.com/src-d/go-mysql-server/sql"
6+
"github.com/src-d/go-mysql-server/sql/expression"
7+
"github.com/src-d/go-mysql-server/sql/expression/function"
8+
"github.com/src-d/go-mysql-server/sql/plan"
9+
)
10+
11+
var (
12+
errMultipleGenerators = errors.NewKind("there can't be more than 1 instance of EXPLODE in a SELECT")
13+
errExplodeNotArray = errors.NewKind("argument of type %q given to EXPLODE, expecting array")
14+
)
15+
16+
func resolveGenerators(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
17+
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
18+
p, ok := n.(*plan.Project)
19+
if !ok {
20+
return n, nil
21+
}
22+
23+
projection := p.Projections
24+
25+
g, err := findGenerator(projection)
26+
if err != nil {
27+
return nil, err
28+
}
29+
30+
// There might be no generator in the project, in that case we don't
31+
// have to do anything.
32+
if g == nil {
33+
return n, nil
34+
}
35+
36+
projection[g.idx] = g.expr
37+
38+
var name string
39+
if n, ok := g.expr.(sql.Nameable); ok {
40+
name = n.Name()
41+
} else {
42+
name = g.expr.String()
43+
}
44+
45+
return plan.NewGenerate(
46+
plan.NewProject(projection, p.Child),
47+
expression.NewGetField(g.idx, g.expr.Type(), name, g.expr.IsNullable()),
48+
), nil
49+
})
50+
}
51+
52+
type generator struct {
53+
idx int
54+
expr sql.Expression
55+
}
56+
57+
// findGenerator will find in the given projection a generator column. If there
58+
// is no generator, it will return nil.
59+
// If there are is than one generator or the argument to explode is not an
60+
// array it will fail.
61+
// All occurrences of Explode will be replaced with Generate.
62+
func findGenerator(exprs []sql.Expression) (*generator, error) {
63+
var g = &generator{idx: -1}
64+
for i, e := range exprs {
65+
var found bool
66+
switch e := e.(type) {
67+
case *function.Explode:
68+
found = true
69+
g.expr = function.NewGenerate(e.Child)
70+
case *expression.Alias:
71+
if exp, ok := e.Child.(*function.Explode); ok {
72+
found = true
73+
g.expr = expression.NewAlias(
74+
function.NewGenerate(exp.Child),
75+
e.Name(),
76+
)
77+
}
78+
}
79+
80+
if found {
81+
if g.idx >= 0 {
82+
return nil, errMultipleGenerators.New()
83+
}
84+
g.idx = i
85+
86+
if !sql.IsArray(g.expr.Type()) {
87+
return nil, errExplodeNotArray.New(g.expr.Type())
88+
}
89+
}
90+
}
91+
92+
if g.expr == nil {
93+
return nil, nil
94+
}
95+
96+
return g, nil
97+
}
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package analyzer
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
"gopkg.in/src-d/go-errors.v1"
8+
"github.com/src-d/go-mysql-server/sql"
9+
"github.com/src-d/go-mysql-server/sql/expression"
10+
"github.com/src-d/go-mysql-server/sql/expression/function"
11+
"github.com/src-d/go-mysql-server/sql/plan"
12+
)
13+
14+
func TestResolveGenerators(t *testing.T) {
15+
testCases := []struct {
16+
name string
17+
node sql.Node
18+
expected sql.Node
19+
err *errors.Kind
20+
}{
21+
{
22+
name: "regular explode",
23+
node: plan.NewProject(
24+
[]sql.Expression{
25+
expression.NewGetField(0, sql.Int64, "a", false),
26+
function.NewExplode(expression.NewGetField(1, sql.Array(sql.Int64), "b", false)),
27+
expression.NewGetField(2, sql.Int64, "c", false),
28+
},
29+
plan.NewUnresolvedTable("foo", ""),
30+
),
31+
expected: plan.NewGenerate(
32+
plan.NewProject(
33+
[]sql.Expression{
34+
expression.NewGetField(0, sql.Int64, "a", false),
35+
function.NewGenerate(expression.NewGetField(1, sql.Array(sql.Int64), "b", false)),
36+
expression.NewGetField(2, sql.Int64, "c", false),
37+
},
38+
plan.NewUnresolvedTable("foo", ""),
39+
),
40+
expression.NewGetField(1, sql.Array(sql.Int64), "EXPLODE(b)", false),
41+
),
42+
err: nil,
43+
},
44+
{
45+
name: "explode with alias",
46+
node: plan.NewProject(
47+
[]sql.Expression{
48+
expression.NewGetField(0, sql.Int64, "a", false),
49+
expression.NewAlias(
50+
function.NewExplode(
51+
expression.NewGetField(1, sql.Array(sql.Int64), "b", false),
52+
),
53+
"x",
54+
),
55+
expression.NewGetField(2, sql.Int64, "c", false),
56+
},
57+
plan.NewUnresolvedTable("foo", ""),
58+
),
59+
expected: plan.NewGenerate(
60+
plan.NewProject(
61+
[]sql.Expression{
62+
expression.NewGetField(0, sql.Int64, "a", false),
63+
expression.NewAlias(
64+
function.NewGenerate(
65+
expression.NewGetField(1, sql.Array(sql.Int64), "b", false),
66+
),
67+
"x",
68+
),
69+
expression.NewGetField(2, sql.Int64, "c", false),
70+
},
71+
plan.NewUnresolvedTable("foo", ""),
72+
),
73+
expression.NewGetField(1, sql.Array(sql.Int64), "x", false),
74+
),
75+
err: nil,
76+
},
77+
{
78+
name: "non array type on explode",
79+
node: plan.NewProject(
80+
[]sql.Expression{
81+
expression.NewGetField(0, sql.Int64, "a", false),
82+
function.NewExplode(expression.NewGetField(1, sql.Int64, "b", false)),
83+
},
84+
plan.NewUnresolvedTable("foo", ""),
85+
),
86+
expected: nil,
87+
err: errExplodeNotArray,
88+
},
89+
{
90+
name: "more than one generator",
91+
node: plan.NewProject(
92+
[]sql.Expression{
93+
expression.NewGetField(0, sql.Int64, "a", false),
94+
function.NewExplode(expression.NewGetField(1, sql.Array(sql.Int64), "b", false)),
95+
function.NewExplode(expression.NewGetField(2, sql.Array(sql.Int64), "c", false)),
96+
},
97+
plan.NewUnresolvedTable("foo", ""),
98+
),
99+
expected: nil,
100+
err: errMultipleGenerators,
101+
},
102+
}
103+
104+
for _, tt := range testCases {
105+
t.Run(tt.name, func(t *testing.T) {
106+
require := require.New(t)
107+
result, err := resolveGenerators(sql.NewEmptyContext(), nil, tt.node)
108+
if tt.err != nil {
109+
require.Error(err)
110+
require.True(tt.err.Is(err))
111+
} else {
112+
require.NoError(err)
113+
require.Equal(tt.expected, result)
114+
}
115+
})
116+
}
117+
}

sql/analyzer/rules.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ var OnceBeforeDefault = []Rule{
3434
// OnceAfterDefault contains the rules to be applied just once after the
3535
// DefaultRules.
3636
var OnceAfterDefault = []Rule{
37+
{"resolve_generators", resolveGenerators},
3738
{"remove_unnecessary_converts", removeUnnecessaryConverts},
3839
{"assign_catalog", assignCatalog},
3940
{"prune_columns", pruneColumns},

sql/analyzer/validation_rules.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ const (
1919
validateIndexCreationRule = "validate_index_creation"
2020
validateCaseResultTypesRule = "validate_case_result_types"
2121
validateIntervalUsageRule = "validate_interval_usage"
22+
validateExplodeUsageRule = "validate_explode_usage"
2223
)
2324

2425
var (
@@ -51,6 +52,11 @@ var (
5152
"invalid use of an interval, which can only be used with DATE_ADD, " +
5253
"DATE_SUB and +/- operators to subtract from or add to a date",
5354
)
55+
// ErrExplodeInvalidUse is returned when an EXPLODE function is used
56+
// outside a Project node.
57+
ErrExplodeInvalidUse = errors.NewKind(
58+
"using EXPLODE is not supported outside a Project node",
59+
)
5460
)
5561

5662
// DefaultValidationRules to apply while analyzing nodes.
@@ -63,6 +69,7 @@ var DefaultValidationRules = []Rule{
6369
{validateIndexCreationRule, validateIndexCreation},
6470
{validateCaseResultTypesRule, validateCaseResultTypes},
6571
{validateIntervalUsageRule, validateIntervalUsage},
72+
{validateExplodeUsageRule, validateExplodeUsage},
6673
}
6774

6875
func validateIsResolved(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
@@ -290,6 +297,31 @@ func validateIntervalUsage(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node,
290297
return n, nil
291298
}
292299

300+
func validateExplodeUsage(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
301+
var invalid bool
302+
plan.InspectExpressions(n, func(e sql.Expression) bool {
303+
// If it's already invalid just skip everything else.
304+
if invalid {
305+
return false
306+
}
307+
308+
// All usage of Explode will be incorrect because the ones in projects
309+
// would have already been converted to Generate, so we only have to
310+
// look for those.
311+
if _, ok := e.(*function.Explode); ok {
312+
invalid = true
313+
}
314+
315+
return true
316+
})
317+
318+
if invalid {
319+
return nil, ErrExplodeInvalidUse.New()
320+
}
321+
322+
return n, nil
323+
}
324+
293325
func stringContains(strs []string, target string) bool {
294326
for _, s := range strs {
295327
if s == target {

0 commit comments

Comments
 (0)