|
| 1 | +// Copyright 2019 The Gitea Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a MIT-style |
| 3 | +// license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +package graceful |
| 6 | + |
| 7 | +import ( |
| 8 | + "fmt" |
| 9 | + "net" |
| 10 | + "os" |
| 11 | + "os/exec" |
| 12 | + "strconv" |
| 13 | + "strings" |
| 14 | + "sync" |
| 15 | +) |
| 16 | + |
| 17 | +const ( |
| 18 | + listenFDs = "LISTEN_FDS" |
| 19 | + startFD = 3 |
| 20 | +) |
| 21 | + |
| 22 | +// In order to keep the working directory the same as when we started we record |
| 23 | +// it at startup. |
| 24 | +var originalWD, _ = os.Getwd() |
| 25 | + |
| 26 | +var ( |
| 27 | + once = sync.Once{} |
| 28 | + mutex = sync.Mutex{} |
| 29 | + |
| 30 | + providedListeners = []net.Listener{} |
| 31 | + activeListeners = []net.Listener{} |
| 32 | +) |
| 33 | + |
| 34 | +func getProvidedFDs() (savedErr error) { |
| 35 | + once.Do(func() { |
| 36 | + mutex.Lock() |
| 37 | + defer mutex.Unlock() |
| 38 | + numFDs := os.Getenv(listenFDs) |
| 39 | + if numFDs == "" { |
| 40 | + return |
| 41 | + } |
| 42 | + n, err := strconv.Atoi(numFDs) |
| 43 | + if err != nil { |
| 44 | + savedErr = fmt.Errorf("%s is not a number: %s. Err: %v", listenFDs, numFDs, err) |
| 45 | + return |
| 46 | + } |
| 47 | + for i := startFD; i < n+startFD; i++ { |
| 48 | + file := os.NewFile(uintptr(i), "listener") |
| 49 | + |
| 50 | + l, err := net.FileListener(file) |
| 51 | + if err == nil { |
| 52 | + if err = file.Close(); err != nil { |
| 53 | + savedErr = fmt.Errorf("error closing provided socket fd %d: %s", i, err) |
| 54 | + return |
| 55 | + } |
| 56 | + providedListeners = append(providedListeners, l) |
| 57 | + continue |
| 58 | + } |
| 59 | + // If needed we can handle packetconns here. |
| 60 | + savedErr = fmt.Errorf("Error getting provided socket fd %d: %v", i, err) |
| 61 | + return |
| 62 | + } |
| 63 | + }) |
| 64 | + return savedErr |
| 65 | +} |
| 66 | + |
| 67 | +// GetListener obtains a listener for the local network address. The network must be |
| 68 | +// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket". It |
| 69 | +// returns an provided net.Listener for the matching network and address, or |
| 70 | +// creates a new one using net.Listen. |
| 71 | +func GetListener(network, address string) (net.Listener, error) { |
| 72 | + switch network { |
| 73 | + default: |
| 74 | + return nil, net.UnknownNetworkError(network) |
| 75 | + case "tcp", "tcp4", "tcp6": |
| 76 | + tcpAddr, err := net.ResolveTCPAddr(network, address) |
| 77 | + if err != nil { |
| 78 | + return nil, err |
| 79 | + } |
| 80 | + return GetListenerTCP(network, tcpAddr) |
| 81 | + case "unix", "unixpacket", "invalid_unix_net_for_test": |
| 82 | + unixAddr, err := net.ResolveUnixAddr(network, address) |
| 83 | + if err != nil { |
| 84 | + return nil, err |
| 85 | + } |
| 86 | + return GetListenerUnix(network, unixAddr) |
| 87 | + } |
| 88 | +} |
| 89 | + |
| 90 | +// GetListenerTCP announces on the local network address. The network must be: |
| 91 | +// "tcp", "tcp4" or "tcp6". It returns a provided net.Listener for the |
| 92 | +// matching network and address, or creates a new one using net.ListenTCP. |
| 93 | +func GetListenerTCP(network string, address *net.TCPAddr) (*net.TCPListener, error) { |
| 94 | + if err := getProvidedFDs(); err != nil { |
| 95 | + return nil, err |
| 96 | + } |
| 97 | + |
| 98 | + mutex.Lock() |
| 99 | + defer mutex.Unlock() |
| 100 | + |
| 101 | + // look for a provided listener |
| 102 | + for i, l := range providedListeners { |
| 103 | + if isSameAddr(l.Addr(), address) { |
| 104 | + providedListeners = append(providedListeners[:i], providedListeners[i+1:]...) |
| 105 | + |
| 106 | + activeListeners = append(activeListeners, l) |
| 107 | + return l.(*net.TCPListener), nil |
| 108 | + } |
| 109 | + } |
| 110 | + |
| 111 | + // make a fresh listener |
| 112 | + l, err := net.ListenTCP(network, address) |
| 113 | + if err != nil { |
| 114 | + return nil, err |
| 115 | + } |
| 116 | + activeListeners = append(activeListeners, l) |
| 117 | + return l, nil |
| 118 | +} |
| 119 | + |
| 120 | +// GetListenerUnix announces on the local network address. The network must be: |
| 121 | +// "unix" or "unixpacket". It returns a provided net.Listener for the |
| 122 | +// matching network and address, or creates a new one using net.ListenUnix. |
| 123 | +func GetListenerUnix(network string, address *net.UnixAddr) (*net.UnixListener, error) { |
| 124 | + if err := getProvidedFDs(); err != nil { |
| 125 | + return nil, err |
| 126 | + } |
| 127 | + |
| 128 | + mutex.Lock() |
| 129 | + defer mutex.Unlock() |
| 130 | + |
| 131 | + // look for a provided listener |
| 132 | + for i, l := range providedListeners { |
| 133 | + if isSameAddr(l.Addr(), address) { |
| 134 | + providedListeners = append(providedListeners[:i], providedListeners[i+1:]...) |
| 135 | + activeListeners = append(activeListeners, l) |
| 136 | + return l.(*net.UnixListener), nil |
| 137 | + } |
| 138 | + } |
| 139 | + |
| 140 | + // make a fresh listener |
| 141 | + l, err := net.ListenUnix(network, address) |
| 142 | + if err != nil { |
| 143 | + return nil, err |
| 144 | + } |
| 145 | + activeListeners = append(activeListeners, l) |
| 146 | + return l, nil |
| 147 | +} |
| 148 | + |
| 149 | +func isSameAddr(a1, a2 net.Addr) bool { |
| 150 | + // If the addresses are not on the same network fail. |
| 151 | + if a1.Network() != a2.Network() { |
| 152 | + return false |
| 153 | + } |
| 154 | + |
| 155 | + // If the two addresses have the same string representation they're equal |
| 156 | + a1s := a1.String() |
| 157 | + a2s := a2.String() |
| 158 | + if a1s == a2s { |
| 159 | + return true |
| 160 | + } |
| 161 | + |
| 162 | + // This allows for ipv6 vs ipv4 local addresses to compare as equal. This |
| 163 | + // scenario is common when listening on localhost. |
| 164 | + const ipv6prefix = "[::]" |
| 165 | + a1s = strings.TrimPrefix(a1s, ipv6prefix) |
| 166 | + a2s = strings.TrimPrefix(a2s, ipv6prefix) |
| 167 | + const ipv4prefix = "0.0.0.0" |
| 168 | + a1s = strings.TrimPrefix(a1s, ipv4prefix) |
| 169 | + a2s = strings.TrimPrefix(a2s, ipv4prefix) |
| 170 | + return a1s == a2s |
| 171 | +} |
| 172 | + |
| 173 | +func getActiveListeners() []net.Listener { |
| 174 | + mutex.Lock() |
| 175 | + defer mutex.Unlock() |
| 176 | + listeners := make([]net.Listener, len(activeListeners)) |
| 177 | + copy(listeners, activeListeners) |
| 178 | + return listeners |
| 179 | +} |
| 180 | + |
| 181 | +// RestartProcess starts a new process passing it the active listeners. It |
| 182 | +// doesn't fork, but starts a new process using the same environment and |
| 183 | +// arguments as when it was originally started. This allows for a newly |
| 184 | +// deployed binary to be started. It returns the pid of the newly started |
| 185 | +// process when successful. |
| 186 | +func RestartProcess() (int, error) { |
| 187 | + listeners := getActiveListeners() |
| 188 | + |
| 189 | + // Extract the fds from the listeners. |
| 190 | + files := make([]*os.File, len(listeners)) |
| 191 | + for i, l := range listeners { |
| 192 | + var err error |
| 193 | + // Now, all our listeners actually have File() functions so instead of |
| 194 | + // individually casting we just use a hacky interface |
| 195 | + files[i], err = l.(filer).File() |
| 196 | + if err != nil { |
| 197 | + return 0, err |
| 198 | + } |
| 199 | + // Remember to close these at the end. |
| 200 | + defer files[i].Close() |
| 201 | + } |
| 202 | + |
| 203 | + // Use the original binary location. This works with symlinks such that if |
| 204 | + // the file it points to has been changed we will use the updated symlink. |
| 205 | + argv0, err := exec.LookPath(os.Args[0]) |
| 206 | + if err != nil { |
| 207 | + return 0, err |
| 208 | + } |
| 209 | + |
| 210 | + // Pass on the environment and replace the old count key with the new one. |
| 211 | + var env []string |
| 212 | + for _, v := range os.Environ() { |
| 213 | + if !strings.HasPrefix(v, listenFDs+"=") { |
| 214 | + env = append(env, v) |
| 215 | + } |
| 216 | + } |
| 217 | + env = append(env, fmt.Sprintf("%s=%d", listenFDs, len(listeners))) |
| 218 | + |
| 219 | + allFiles := append([]*os.File{os.Stdin, os.Stdout, os.Stderr}, files...) |
| 220 | + process, err := os.StartProcess(argv0, os.Args, &os.ProcAttr{ |
| 221 | + Dir: originalWD, |
| 222 | + Env: env, |
| 223 | + Files: allFiles, |
| 224 | + }) |
| 225 | + if err != nil { |
| 226 | + return 0, err |
| 227 | + } |
| 228 | + return process.Pid, nil |
| 229 | +} |
| 230 | + |
| 231 | +type filer interface { |
| 232 | + File() (*os.File, error) |
| 233 | +} |
0 commit comments