Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

sql: add Databaser interface for cleaner analyzer rules #592

Merged
merged 1 commit into from
Jan 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sql/analyzer/assign_catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestAssignCatalog(t *testing.T) {

si, ok := node.(*plan.ShowIndexes)
require.True(ok)
require.Equal(db, si.Database)
require.Equal(db, si.Database())
require.Equal(c.IndexRegistry, si.Registry)

node, err = f.Apply(sql.NewEmptyContext(), a, plan.NewShowProcessList())
Expand Down
65 changes: 16 additions & 49 deletions sql/analyzer/resolve_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package analyzer

import (
"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/plan"
)

func resolveDatabase(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
Expand All @@ -12,59 +11,27 @@ func resolveDatabase(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error
a.Log("resolve database, node of type: %T", n)

return n.TransformUp(func(n sql.Node) (sql.Node, error) {
switch v := n.(type) {
case *plan.ShowIndexes:
db, err := a.Catalog.Database(a.Catalog.CurrentDatabase())
if err != nil {
return nil, err
}

nc := *v
nc.Database = db
return &nc, nil
case *plan.ShowTables:
var dbName = v.Database.Name()
if dbName == "" {
dbName = a.Catalog.CurrentDatabase()
}

db, err := a.Catalog.Database(dbName)
if err != nil {
return nil, err
}

nc := *v
nc.Database = db
return &nc, nil
case *plan.CreateTable:
db, err := a.Catalog.Database(a.Catalog.CurrentDatabase())
if err != nil {
return nil, err
}
d, ok := n.(sql.Databaser)
if !ok {
return n, nil
}

nc := *v
nc.Database = db
return &nc, nil
case *plan.Use:
db, err := a.Catalog.Database(v.Database.Name())
if err != nil {
return nil, err
var dbName = a.Catalog.CurrentDatabase()
if db := d.Database(); db != nil {
if _, ok := db.(sql.UnresolvedDatabase); !ok {
return n, nil
}

nc := *v
nc.Database = db
return &nc, nil
case *plan.ShowCreateDatabase:
db, err := a.Catalog.Database(v.Database.Name())
if err != nil {
return nil, err
if db.Name() != "" {
dbName = db.Name()
}
}

nc := *v
nc.Database = db
return &nc, nil
default:
return n, nil
db, err := a.Catalog.Database(dbName)
if err != nil {
return nil, err
}

return d.WithDatabase(db)
})
}
9 changes: 9 additions & 0 deletions sql/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,15 @@ type Expressioner interface {
TransformExpressions(TransformExprFunc) (Node, error)
}

// Databaser is a node that contains a reference to a database.
type Databaser interface {
// Database the current database.
Database() Database
// WithDatabase returns a new node instance with the database replaced with
// the one given as parameter.
WithDatabase(Database) (Node, error)
}

// Partition represents a partition from a SQL table.
type Partition interface {
Key() []byte
Expand Down
34 changes: 24 additions & 10 deletions sql/plan/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ var ErrCreateTable = errors.NewKind("tables cannot be created on database %s")

// CreateTable is a node describing the creation of some table.
type CreateTable struct {
Database sql.Database
name string
schema sql.Schema
db sql.Database
name string
schema sql.Schema
}

// NewCreateTable creates a new CreateTable node
Expand All @@ -22,23 +22,37 @@ func NewCreateTable(db sql.Database, name string, schema sql.Schema) *CreateTabl
}

return &CreateTable{
Database: db,
name: name,
schema: schema,
db: db,
name: name,
schema: schema,
}
}

var _ sql.Databaser = (*CreateTable)(nil)

// Database implements the sql.Databaser interface.
func (c *CreateTable) Database() sql.Database {
return c.db
}

// WithDatabase implements the sql.Databaser interface.
func (c *CreateTable) WithDatabase(db sql.Database) (sql.Node, error) {
nc := *c
nc.db = db
return &nc, nil
}

// Resolved implements the Resolvable interface.
func (c *CreateTable) Resolved() bool {
_, ok := c.Database.(sql.UnresolvedDatabase)
_, ok := c.db.(sql.UnresolvedDatabase)
return !ok
}

// RowIter implements the Node interface.
func (c *CreateTable) RowIter(s *sql.Context) (sql.RowIter, error) {
d, ok := c.Database.(sql.Alterable)
d, ok := c.db.(sql.Alterable)
if !ok {
return nil, ErrCreateTable.New(c.Database.Name())
return nil, ErrCreateTable.New(c.db.Name())
}

return sql.RowsToRowIter(), d.Create(c.name, c.schema)
Expand All @@ -52,7 +66,7 @@ func (c *CreateTable) Children() []sql.Node { return nil }

// TransformUp implements the Transformable interface.
func (c *CreateTable) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) {
return f(NewCreateTable(c.Database, c.name, c.schema))
return f(NewCreateTable(c.db, c.name, c.schema))
}

// TransformExpressionsUp implements the Transformable interface.
Expand Down
22 changes: 18 additions & 4 deletions sql/plan/show_create_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

// ShowCreateDatabase returns the SQL for creating a database.
type ShowCreateDatabase struct {
Database sql.Database
db sql.Database
IfNotExists bool
}

Expand All @@ -25,9 +25,23 @@ func NewShowCreateDatabase(db sql.Database, ifNotExists bool) *ShowCreateDatabas
return &ShowCreateDatabase{db, ifNotExists}
}

var _ sql.Databaser = (*ShowCreateDatabase)(nil)

// Database implements the sql.Databaser interface.
func (s *ShowCreateDatabase) Database() sql.Database {
return s.db
}

// WithDatabase implements the sql.Databaser interface.
func (s *ShowCreateDatabase) WithDatabase(db sql.Database) (sql.Node, error) {
nc := *s
nc.db = db
return &nc, nil
}

// RowIter implements the sql.Node interface.
func (s *ShowCreateDatabase) RowIter(ctx *sql.Context) (sql.RowIter, error) {
var name = s.Database.Name()
var name = s.db.Name()

var buf bytes.Buffer

Expand Down Expand Up @@ -56,15 +70,15 @@ func (s *ShowCreateDatabase) Schema() sql.Schema {
}

func (s *ShowCreateDatabase) String() string {
return fmt.Sprintf("SHOW CREATE DATABASE %s", s.Database.Name())
return fmt.Sprintf("SHOW CREATE DATABASE %s", s.db.Name())
}

// Children implements the sql.Node interface.
func (s *ShowCreateDatabase) Children() []sql.Node { return nil }

// Resolved implements the sql.Node interface.
func (s *ShowCreateDatabase) Resolved() bool {
_, ok := s.Database.(sql.UnresolvedDatabase)
_, ok := s.db.(sql.UnresolvedDatabase)
return !ok
}

Expand Down
22 changes: 18 additions & 4 deletions sql/plan/show_indexes.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

// ShowIndexes is a node that shows the indexes on a table.
type ShowIndexes struct {
Database sql.Database
db sql.Database
Table string
Registry *sql.IndexRegistry
}
Expand All @@ -19,15 +19,29 @@ func NewShowIndexes(db sql.Database, table string, registry *sql.IndexRegistry)
return &ShowIndexes{db, table, registry}
}

var _ sql.Databaser = (*ShowIndexes)(nil)

// Database implements the sql.Databaser interface.
func (n *ShowIndexes) Database() sql.Database {
return n.db
}

// WithDatabase implements the sql.Databaser interface.
func (n *ShowIndexes) WithDatabase(db sql.Database) (sql.Node, error) {
nc := *n
nc.db = db
return &nc, nil
}

// Resolved implements the Resolvable interface.
func (n *ShowIndexes) Resolved() bool {
_, ok := n.Database.(sql.UnresolvedDatabase)
_, ok := n.db.(sql.UnresolvedDatabase)
return !ok
}

// TransformUp implements the Transformable interface.
func (n *ShowIndexes) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) {
return f(NewShowIndexes(n.Database, n.Table, n.Registry))
return f(NewShowIndexes(n.db, n.Table, n.Registry))
}

// TransformExpressionsUp implements the Transformable interface.
Expand Down Expand Up @@ -67,7 +81,7 @@ func (n *ShowIndexes) Children() []sql.Node { return nil }
// RowIter implements the Node interface.
func (n *ShowIndexes) RowIter(*sql.Context) (sql.RowIter, error) {
return &showIndexesIter{
db: n.Database,
db: n.db,
table: n.Table,
registry: n.Registry,
}, nil
Expand Down
28 changes: 21 additions & 7 deletions sql/plan/show_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (

// ShowTables is a node that shows the database tables.
type ShowTables struct {
Database sql.Database
Full bool
db sql.Database
Full bool
}

var showTablesSchema = sql.Schema{
Expand All @@ -24,14 +24,28 @@ var showTablesFullSchema = sql.Schema{
// NewShowTables creates a new show tables node given a database.
func NewShowTables(database sql.Database, full bool) *ShowTables {
return &ShowTables{
Database: database,
Full: full,
db: database,
Full: full,
}
}

var _ sql.Databaser = (*ShowTables)(nil)

// Database implements the sql.Databaser interface.
func (p *ShowTables) Database() sql.Database {
return p.db
}

// WithDatabase implements the sql.Databaser interface.
func (p *ShowTables) WithDatabase(db sql.Database) (sql.Node, error) {
nc := *p
nc.db = db
return &nc, nil
}

// Resolved implements the Resolvable interface.
func (p *ShowTables) Resolved() bool {
_, ok := p.Database.(sql.UnresolvedDatabase)
_, ok := p.db.(sql.UnresolvedDatabase)
return !ok
}

Expand All @@ -52,7 +66,7 @@ func (p *ShowTables) Schema() sql.Schema {
// RowIter implements the Node interface.
func (p *ShowTables) RowIter(ctx *sql.Context) (sql.RowIter, error) {
tableNames := []string{}
for key := range p.Database.Tables() {
for key := range p.db.Tables() {
tableNames = append(tableNames, key)
}

Expand All @@ -72,7 +86,7 @@ func (p *ShowTables) RowIter(ctx *sql.Context) (sql.RowIter, error) {

// TransformUp implements the Transformable interface.
func (p *ShowTables) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) {
return f(NewShowTables(p.Database, p.Full))
return f(NewShowTables(p.db, p.Full))
}

// TransformExpressionsUp implements the Transformable interface.
Expand Down
25 changes: 19 additions & 6 deletions sql/plan/use.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,36 @@ import (

// Use changes the current database.
type Use struct {
Database sql.Database
Catalog *sql.Catalog
db sql.Database
Catalog *sql.Catalog
}

// NewUse creates a new Use node.
func NewUse(db sql.Database) *Use {
return &Use{Database: db}
return &Use{db: db}
}

var _ sql.Node = (*Use)(nil)
var _ sql.Databaser = (*Use)(nil)

// Database implements the sql.Databaser interface.
func (u *Use) Database() sql.Database {
return u.db
}

// WithDatabase implements the sql.Databaser interface.
func (u *Use) WithDatabase(db sql.Database) (sql.Node, error) {
nc := *u
nc.db = db
return &nc, nil
}

// Children implements the sql.Node interface.
func (Use) Children() []sql.Node { return nil }

// Resolved implements the sql.Node interface.
func (u *Use) Resolved() bool {
_, ok := u.Database.(sql.UnresolvedDatabase)
_, ok := u.db.(sql.UnresolvedDatabase)
return !ok
}

Expand All @@ -33,7 +46,7 @@ func (Use) Schema() sql.Schema { return nil }

// RowIter implements the sql.Node interface.
func (u *Use) RowIter(ctx *sql.Context) (sql.RowIter, error) {
u.Catalog.SetCurrentDatabase(u.Database.Name())
u.Catalog.SetCurrentDatabase(u.db.Name())
return sql.RowsToRowIter(), nil
}

Expand All @@ -49,5 +62,5 @@ func (u *Use) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error)

// String implements the sql.Node interface.
func (u *Use) String() string {
return fmt.Sprintf("USE(%s)", u.Database.Name())
return fmt.Sprintf("USE(%s)", u.db.Name())
}