Skip to content

Commit 2fd708a

Browse files
authored
Ensure that schema search path is set with every connection on postgres (#14131) (#14216)
Backport #14131 Unfortunately every connection to postgres requires that the search path is set appropriately. This PR shadows the postgres driver to ensure that as soon as a connection is open, the search_path is set appropriately. Fix #14088 Signed-off-by: Andrew Thornton <[email protected]>
1 parent 7a0a133 commit 2fd708a

File tree

2 files changed

+85
-11
lines changed

2 files changed

+85
-11
lines changed

models/models.go

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,16 @@ func getEngine() (*xorm.Engine, error) {
145145
return nil, err
146146
}
147147

148-
engine, err := xorm.NewEngine(setting.Database.Type, connStr)
148+
var engine *xorm.Engine
149+
150+
if setting.Database.UsePostgreSQL && len(setting.Database.Schema) > 0 {
151+
// OK whilst we sort out our schema issues - create a schema aware postgres
152+
registerPostgresSchemaDriver()
153+
engine, err = xorm.NewEngine("postgresschema", connStr)
154+
} else {
155+
engine, err = xorm.NewEngine(setting.Database.Type, connStr)
156+
}
157+
149158
if err != nil {
150159
return nil, err
151160
}
@@ -155,16 +164,6 @@ func getEngine() (*xorm.Engine, error) {
155164
engine.Dialect().SetParams(map[string]string{"DEFAULT_VARCHAR": "nvarchar"})
156165
}
157166
engine.SetSchema(setting.Database.Schema)
158-
if setting.Database.UsePostgreSQL && len(setting.Database.Schema) > 0 {
159-
// Add the schema to the search path
160-
if _, err := engine.Exec(`SELECT set_config(
161-
'search_path',
162-
? || ',' || current_setting('search_path'),
163-
false)`,
164-
setting.Database.Schema); err != nil {
165-
return nil, err
166-
}
167-
}
168167
return engine, nil
169168
}
170169

models/sql_postgres_with_schema.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Copyright 2020 The Gitea Authors. All rights reserved.
2+
// Use of this source code is governed by a MIT-style
3+
// license that can be found in the LICENSE file.
4+
5+
package models
6+
7+
import (
8+
"database/sql"
9+
"database/sql/driver"
10+
"sync"
11+
12+
"code.gitea.io/gitea/modules/setting"
13+
14+
"github.com/lib/pq"
15+
"xorm.io/xorm/dialects"
16+
)
17+
18+
var registerOnce sync.Once
19+
20+
func registerPostgresSchemaDriver() {
21+
registerOnce.Do(func() {
22+
sql.Register("postgresschema", &postgresSchemaDriver{})
23+
dialects.RegisterDriver("postgresschema", dialects.QueryDriver("postgres"))
24+
})
25+
}
26+
27+
type postgresSchemaDriver struct {
28+
pq.Driver
29+
}
30+
31+
// Open opens a new connection to the database. name is a connection string.
32+
// This function opens the postgres connection in the default manner but immediately
33+
// runs set_config to set the search_path appropriately
34+
func (d *postgresSchemaDriver) Open(name string) (driver.Conn, error) {
35+
conn, err := d.Driver.Open(name)
36+
if err != nil {
37+
return conn, err
38+
}
39+
schemaValue, _ := driver.String.ConvertValue(setting.Database.Schema)
40+
41+
// golangci lint is incorrect here - there is no benefit to using driver.ExecerContext here
42+
// and in any case pq does not implement it
43+
if execer, ok := conn.(driver.Execer); ok { //nolint
44+
_, err := execer.Exec(`SELECT set_config(
45+
'search_path',
46+
$1 || ',' || current_setting('search_path'),
47+
false)`, []driver.Value{schemaValue}) //nolint
48+
if err != nil {
49+
_ = conn.Close()
50+
return nil, err
51+
}
52+
return conn, nil
53+
}
54+
55+
stmt, err := conn.Prepare(`SELECT set_config(
56+
'search_path',
57+
$1 || ',' || current_setting('search_path'),
58+
false)`)
59+
if err != nil {
60+
_ = conn.Close()
61+
return nil, err
62+
}
63+
defer stmt.Close()
64+
65+
// driver.String.ConvertValue will never return err for string
66+
67+
// golangci lint is incorrect here - there is no benefit to using stmt.ExecWithContext here
68+
_, err = stmt.Exec([]driver.Value{schemaValue}) //nolint
69+
if err != nil {
70+
_ = conn.Close()
71+
return nil, err
72+
}
73+
74+
return conn, nil
75+
}

0 commit comments

Comments
 (0)