|
| 1 | +// Copyright 2020 The Gitea Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a MIT-style |
| 3 | +// license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +package models |
| 6 | + |
| 7 | +import ( |
| 8 | + "database/sql" |
| 9 | + "database/sql/driver" |
| 10 | + "sync" |
| 11 | + |
| 12 | + "code.gitea.io/gitea/modules/setting" |
| 13 | + |
| 14 | + "github.com/lib/pq" |
| 15 | + "xorm.io/xorm/dialects" |
| 16 | +) |
| 17 | + |
| 18 | +var registerOnce sync.Once |
| 19 | + |
| 20 | +func registerPostgresSchemaDriver() { |
| 21 | + registerOnce.Do(func() { |
| 22 | + sql.Register("postgresschema", &postgresSchemaDriver{}) |
| 23 | + dialects.RegisterDriver("postgresschema", dialects.QueryDriver("postgres")) |
| 24 | + }) |
| 25 | +} |
| 26 | + |
| 27 | +type postgresSchemaDriver struct { |
| 28 | + pq.Driver |
| 29 | +} |
| 30 | + |
| 31 | +// Open opens a new connection to the database. name is a connection string. |
| 32 | +// This function opens the postgres connection in the default manner but immediately |
| 33 | +// runs set_config to set the search_path appropriately |
| 34 | +func (d *postgresSchemaDriver) Open(name string) (driver.Conn, error) { |
| 35 | + conn, err := d.Driver.Open(name) |
| 36 | + if err != nil { |
| 37 | + return conn, err |
| 38 | + } |
| 39 | + schemaValue, _ := driver.String.ConvertValue(setting.Database.Schema) |
| 40 | + |
| 41 | + // golangci lint is incorrect here - there is no benefit to using driver.ExecerContext here |
| 42 | + // and in any case pq does not implement it |
| 43 | + if execer, ok := conn.(driver.Execer); ok { //nolint |
| 44 | + _, err := execer.Exec(`SELECT set_config( |
| 45 | + 'search_path', |
| 46 | + $1 || ',' || current_setting('search_path'), |
| 47 | + false)`, []driver.Value{schemaValue}) //nolint |
| 48 | + if err != nil { |
| 49 | + _ = conn.Close() |
| 50 | + return nil, err |
| 51 | + } |
| 52 | + return conn, nil |
| 53 | + } |
| 54 | + |
| 55 | + stmt, err := conn.Prepare(`SELECT set_config( |
| 56 | + 'search_path', |
| 57 | + $1 || ',' || current_setting('search_path'), |
| 58 | + false)`) |
| 59 | + if err != nil { |
| 60 | + _ = conn.Close() |
| 61 | + return nil, err |
| 62 | + } |
| 63 | + defer stmt.Close() |
| 64 | + |
| 65 | + // driver.String.ConvertValue will never return err for string |
| 66 | + |
| 67 | + // golangci lint is incorrect here - there is no benefit to using stmt.ExecWithContext here |
| 68 | + _, err = stmt.Exec([]driver.Value{schemaValue}) //nolint |
| 69 | + if err != nil { |
| 70 | + _ = conn.Close() |
| 71 | + return nil, err |
| 72 | + } |
| 73 | + |
| 74 | + return conn, nil |
| 75 | +} |
0 commit comments