Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 33153ed

Browse files
authored
Merge pull request #526 from erizocosmico/feature/unnecessary-casts
sql/analyzer: add rule to avoid unnecessary casts
2 parents 4a751c0 + d7e7d21 commit 33153ed

File tree

3 files changed

+67
-0
lines changed

3 files changed

+67
-0
lines changed

sql/analyzer/optimization_rules.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,25 @@ func moveJoinConditionsToFilter(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.
245245
})
246246
}
247247

248+
func removeUnnecessaryConverts(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
249+
span, _ := ctx.Span("remove_unnecessary_converts")
250+
defer span.Finish()
251+
252+
if !n.Resolved() {
253+
return n, nil
254+
}
255+
256+
a.Log("removing unnecessary converts, node of type: %T", n)
257+
258+
return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) {
259+
if c, ok := e.(*expression.Convert); ok && c.Child.Type() == c.Type() {
260+
return c.Child, nil
261+
}
262+
263+
return e, nil
264+
})
265+
}
266+
248267
// containsSources checks that all `needle` sources are contained inside `haystack`.
249268
func containsSources(haystack, needle []string) bool {
250269
for _, s := range needle {

sql/analyzer/optimization_rules_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,3 +396,50 @@ func TestEvalFilter(t *testing.T) {
396396
})
397397
}
398398
}
399+
400+
func TestRemoveUnnecessaryConverts(t *testing.T) {
401+
testCases := []struct {
402+
name string
403+
childExpr sql.Expression
404+
castType string
405+
expected sql.Expression
406+
}{
407+
{
408+
"unnecessary cast",
409+
expression.NewLiteral([]byte{}, sql.Blob),
410+
"binary",
411+
expression.NewLiteral([]byte{}, sql.Blob),
412+
},
413+
{
414+
"necessary cast",
415+
expression.NewLiteral("foo", sql.Text),
416+
"signed",
417+
expression.NewConvert(
418+
expression.NewLiteral("foo", sql.Text),
419+
"signed",
420+
),
421+
},
422+
}
423+
424+
for _, tt := range testCases {
425+
t.Run(tt.name, func(t *testing.T) {
426+
require := require.New(t)
427+
428+
node := plan.NewProject([]sql.Expression{
429+
expression.NewConvert(tt.childExpr, tt.castType),
430+
},
431+
plan.NewResolvedTable(mem.NewTable("foo", nil)),
432+
)
433+
434+
result, err := removeUnnecessaryConverts(
435+
sql.NewEmptyContext(),
436+
NewDefault(nil),
437+
node,
438+
)
439+
require.NoError(err)
440+
441+
resultExpr := result.(*plan.Project).Projections[0]
442+
require.Equal(tt.expected, resultExpr)
443+
})
444+
}
445+
}

sql/analyzer/rules.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ var OnceBeforeDefault = []Rule{
3232
// OnceAfterDefault contains the rules to be applied just once after the
3333
// DefaultRules.
3434
var OnceAfterDefault = []Rule{
35+
{"remove_unnecessary_converts", removeUnnecessaryConverts},
3536
{"assign_catalog", assignCatalog},
3637
{"pushdown", pushdown},
3738
{"erase_projection", eraseProjection},

0 commit comments

Comments
 (0)