Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 3fe0e4c

Browse files
authored
sql/analyzer: alias projected columns wrapped in converts (#701)
sql/analyzer: alias projected columns wrapped in converts
2 parents bb8122d + cd4137e commit 3fe0e4c

File tree

4 files changed

+224
-22
lines changed

4 files changed

+224
-22
lines changed

engine_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,14 @@ var queries = []struct {
10521052
"SELECT DATE_ADD('9999-12-31 23:59:59', INTERVAL 1 DAY)",
10531053
[]sql.Row{{nil}},
10541054
},
1055+
{
1056+
`SELECT t.date_col FROM (SELECT CONVERT('2019-06-06 00:00:00', DATETIME) as date_col) t WHERE t.date_col > '0000-01-01 00:00:00'`,
1057+
[]sql.Row{{time.Date(2019, time.June, 6, 0, 0, 0, 0, time.UTC)}},
1058+
},
1059+
{
1060+
`SELECT t.date_col FROM (SELECT CONVERT('2019-06-06 00:00:00', DATETIME) as date_col) t GROUP BY t.date_col`,
1061+
[]sql.Row{{time.Date(2019, time.June, 6, 0, 0, 0, 0, time.UTC)}},
1062+
},
10551063
}
10561064

10571065
func TestQueries(t *testing.T) {

sql/analyzer/convert_dates.go

Lines changed: 137 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,156 @@
11
package analyzer
22

33
import (
4+
"fmt"
5+
46
"gopkg.in/src-d/go-mysql-server.v0/sql"
57
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
68
"gopkg.in/src-d/go-mysql-server.v0/sql/expression/function"
9+
"gopkg.in/src-d/go-mysql-server.v0/sql/plan"
710
)
811

12+
type tableCol struct {
13+
table string
14+
col string
15+
}
16+
917
// convertDates wraps all expressions of date and datetime type with converts
1018
// to ensure the date range is validated.
1119
func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
1220
if !n.Resolved() {
1321
return n, nil
1422
}
1523

16-
return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) {
17-
// No need to wrap expressions that already validate times, such as
18-
// convert, date_add, etc and those expressions whose Type method
19-
// cannot be called because they are placeholders.
20-
switch e.(type) {
21-
case *expression.Convert,
22-
*expression.Arithmetic,
23-
*function.DateAdd,
24-
*function.DateSub,
25-
*expression.Star,
26-
*expression.DefaultColumn,
27-
*expression.Alias:
28-
return e, nil
29-
default:
30-
switch e.Type() {
31-
case sql.Date:
32-
return expression.NewConvert(e, expression.ConvertToDate), nil
33-
case sql.Timestamp:
34-
return expression.NewConvert(e, expression.ConvertToDatetime), nil
35-
default:
36-
return e, nil
24+
// Replacements contains a mapping from columns to the alias they will be
25+
// replaced by.
26+
var replacements = make(map[tableCol]string)
27+
28+
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
29+
exp, ok := n.(sql.Expressioner)
30+
if !ok {
31+
return n, nil
32+
}
33+
34+
// nodeReplacements are all the replacements found in the current node.
35+
// These replacements are not applied to the current node, only to
36+
// parent nodes.
37+
var nodeReplacements = make(map[tableCol]string)
38+
39+
var expressions = make(map[string]bool)
40+
switch exp := exp.(type) {
41+
case *plan.Project:
42+
for _, e := range exp.Projections {
43+
expressions[e.String()] = true
44+
}
45+
case *plan.GroupBy:
46+
for _, e := range exp.Aggregate {
47+
expressions[e.String()] = true
3748
}
3849
}
50+
51+
var result sql.Node
52+
var err error
53+
switch exp := exp.(type) {
54+
case *plan.GroupBy:
55+
var aggregate = make([]sql.Expression, len(exp.Aggregate))
56+
for i, a := range exp.Aggregate {
57+
agg, err := a.TransformUp(func(e sql.Expression) (sql.Expression, error) {
58+
return addDateConvert(e, exp, replacements, nodeReplacements, expressions, true)
59+
})
60+
if err != nil {
61+
return nil, err
62+
}
63+
aggregate[i] = agg
64+
}
65+
66+
var grouping = make([]sql.Expression, len(exp.Grouping))
67+
for i, g := range exp.Grouping {
68+
gr, err := g.TransformUp(func(e sql.Expression) (sql.Expression, error) {
69+
return addDateConvert(e, exp, replacements, nodeReplacements, expressions, false)
70+
})
71+
if err != nil {
72+
return nil, err
73+
}
74+
grouping[i] = gr
75+
}
76+
77+
result = plan.NewGroupBy(aggregate, grouping, exp.Child)
78+
default:
79+
result, err = exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
80+
return addDateConvert(e, n, replacements, nodeReplacements, expressions, true)
81+
})
82+
}
83+
84+
if err != nil {
85+
return nil, err
86+
}
87+
88+
// We're done with this node, so copy all the replacements found in
89+
// this node to the global replacements in order to make the necesssary
90+
// changes in parent nodes.
91+
for tc, n := range nodeReplacements {
92+
replacements[tc] = n
93+
}
94+
95+
return result, err
3996
})
4097
}
98+
99+
func addDateConvert(
100+
e sql.Expression,
101+
node sql.Node,
102+
replacements, nodeReplacements map[tableCol]string,
103+
expressions map[string]bool,
104+
aliasRootProjections bool,
105+
) (sql.Expression, error) {
106+
var result sql.Expression
107+
108+
// No need to wrap expressions that already validate times, such as
109+
// convert, date_add, etc and those expressions whose Type method
110+
// cannot be called because they are placeholders.
111+
switch e.(type) {
112+
case *expression.Convert,
113+
*expression.Arithmetic,
114+
*function.DateAdd,
115+
*function.DateSub,
116+
*expression.Star,
117+
*expression.DefaultColumn,
118+
*expression.Alias:
119+
return e, nil
120+
default:
121+
// If it's a replacement, just replace it with the correct GetField
122+
// because we know that it's already converted to a correct date
123+
// and there is no point to do so again.
124+
if gf, ok := e.(*expression.GetField); ok {
125+
if name, ok := replacements[tableCol{gf.Table(), gf.Name()}]; ok {
126+
return expression.NewGetField(gf.Index(), gf.Type(), name, gf.IsNullable()), nil
127+
}
128+
}
129+
130+
switch e.Type() {
131+
case sql.Date:
132+
result = expression.NewConvert(e, expression.ConvertToDate)
133+
case sql.Timestamp:
134+
result = expression.NewConvert(e, expression.ConvertToDatetime)
135+
default:
136+
result = e
137+
}
138+
}
139+
140+
// Only do this if it's a root expression in a project or group by.
141+
switch node.(type) {
142+
case *plan.Project, *plan.GroupBy:
143+
// If it was originally a GetField, and it's not anymore it's
144+
// because we wrapped it in a convert. We need to make it an alias
145+
// and propagate the changes up the chain.
146+
if gf, ok := e.(*expression.GetField); ok && expressions[e.String()] && aliasRootProjections {
147+
if _, ok := result.(*expression.GetField); !ok {
148+
name := fmt.Sprintf("%s__%s", gf.Table(), gf.Name())
149+
result = expression.NewAlias(result, name)
150+
nodeReplacements[tableCol{gf.Table(), gf.Name()}] = name
151+
}
152+
}
153+
}
154+
155+
return result, nil
156+
}

sql/analyzer/convert_dates_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,84 @@ func TestConvertDates(t *testing.T) {
146146
}
147147
}
148148

149+
func TestConvertDatesProject(t *testing.T) {
150+
table := plan.NewResolvedTable(mem.NewTable("t", nil))
151+
input := plan.NewFilter(
152+
expression.NewEquals(
153+
expression.NewGetField(0, sql.Int64, "foo", false),
154+
expression.NewLiteral("2019-06-06 00:00:00", sql.Text),
155+
),
156+
plan.NewProject([]sql.Expression{
157+
expression.NewGetField(0, sql.Timestamp, "foo", false),
158+
}, table),
159+
)
160+
expected := plan.NewFilter(
161+
expression.NewEquals(
162+
expression.NewGetField(0, sql.Int64, "__foo", false),
163+
expression.NewLiteral("2019-06-06 00:00:00", sql.Text),
164+
),
165+
plan.NewProject([]sql.Expression{
166+
expression.NewAlias(
167+
expression.NewConvert(
168+
expression.NewGetField(0, sql.Timestamp, "foo", false),
169+
expression.ConvertToDatetime,
170+
),
171+
"__foo",
172+
),
173+
}, table),
174+
)
175+
176+
result, err := convertDates(sql.NewEmptyContext(), nil, input)
177+
require.NoError(t, err)
178+
require.Equal(t, expected, result)
179+
}
180+
181+
func TestConvertDatesGroupBy(t *testing.T) {
182+
table := plan.NewResolvedTable(mem.NewTable("t", nil))
183+
input := plan.NewFilter(
184+
expression.NewEquals(
185+
expression.NewGetField(0, sql.Int64, "foo", false),
186+
expression.NewLiteral("2019-06-06 00:00:00", sql.Text),
187+
),
188+
plan.NewGroupBy(
189+
[]sql.Expression{
190+
expression.NewGetField(0, sql.Timestamp, "foo", false),
191+
},
192+
[]sql.Expression{
193+
expression.NewGetField(0, sql.Timestamp, "foo", false),
194+
}, table,
195+
),
196+
)
197+
expected := plan.NewFilter(
198+
expression.NewEquals(
199+
expression.NewGetField(0, sql.Int64, "__foo", false),
200+
expression.NewLiteral("2019-06-06 00:00:00", sql.Text),
201+
),
202+
plan.NewGroupBy(
203+
[]sql.Expression{
204+
expression.NewAlias(
205+
expression.NewConvert(
206+
expression.NewGetField(0, sql.Timestamp, "foo", false),
207+
expression.ConvertToDatetime,
208+
),
209+
"__foo",
210+
),
211+
},
212+
[]sql.Expression{
213+
expression.NewConvert(
214+
expression.NewGetField(0, sql.Timestamp, "foo", false),
215+
expression.ConvertToDatetime,
216+
),
217+
},
218+
table,
219+
),
220+
)
221+
222+
result, err := convertDates(sql.NewEmptyContext(), nil, input)
223+
require.NoError(t, err)
224+
require.Equal(t, expected, result)
225+
}
226+
149227
func newDateAdd(l, r sql.Expression) sql.Expression {
150228
e, _ := function.NewDateAdd(l, r)
151229
return e

sql/analyzer/rules.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ var DefaultRules = []Rule{
2020
{"reorder_projection", reorderProjection},
2121
{"move_join_conds_to_filter", moveJoinConditionsToFilter},
2222
{"eval_filter", evalFilter},
23-
{"convert_dates", convertDates},
2423
{"optimize_distinct", optimizeDistinct},
2524
}
2625

@@ -36,6 +35,7 @@ var OnceBeforeDefault = []Rule{
3635
// DefaultRules.
3736
var OnceAfterDefault = []Rule{
3837
{"remove_unnecessary_converts", removeUnnecessaryConverts},
38+
{"convert_dates", convertDates},
3939
{"assign_catalog", assignCatalog},
4040
{"prune_columns", pruneColumns},
4141
{"pushdown", pushdown},

0 commit comments

Comments
 (0)