Skip to content

Commit 52d5636

Browse files
authored
close() closes immediately; Add new closeGracefully() (#383)
Fixes #370. This patch changes the behavior of `PostgresConnection.close()`. Currently `close()` terminates the connection only after all queued queries have been successfully processed by the server. This however leads to an unwanted dependency on the Postgres server to close a connection. If a server stops responding, the client is currently unable to close its connection. Because of this, this patch changes the behavior of `close()`. `close()` now terminates a connection immediately and fails all running or queued queries. To allow users to continue to use the existing behavior we introduce a `closeGracefully()` that now has the same behavior as close had previously. Since we never documented the old close behavior and we consider it dangerous in certain situations we are fine with changing the behavior without tagging a major version.
1 parent a5758b0 commit 52d5636

12 files changed

+263
-109
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,17 @@ extension PostgresConnection {
384384
try await self.close().get()
385385
}
386386

387+
/// Closes the connection to the server, _after all queries_ that have been created on this connection have been run.
388+
public func closeGracefully() async throws {
389+
try await withTaskCancellationHandler { () async throws -> () in
390+
let promise = self.eventLoop.makePromise(of: Void.self)
391+
self.channel.triggerUserOutboundEvent(PSQLOutgoingEvent.gracefulShutdown, promise: promise)
392+
return try await promise.futureResult.get()
393+
} onCancel: {
394+
_ = self.close()
395+
}
396+
}
397+
387398
/// Run a query on the Postgres server the connection is connected to.
388399
///
389400
/// - Parameters:

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 87 additions & 77 deletions
Large diffs are not rendered by default.

Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,14 @@ struct ListenStateMachine {
3636
}
3737

3838
mutating func stopListeningSucceeded(channel: String) -> StopListeningSuccessAction {
39-
return self.channels[channel, default: .init()].stopListeningSucceeded()
39+
switch self.channels[channel]!.stopListeningSucceeded() {
40+
case .none:
41+
self.channels.removeValue(forKey: channel)
42+
return .none
43+
44+
case .startListening:
45+
return .startListening
46+
}
4047
}
4148

4249
enum CancelAction {
@@ -46,7 +53,7 @@ struct ListenStateMachine {
4653
}
4754

4855
mutating func cancelNotificationListener(channel: String, id: Int) -> CancelAction {
49-
return self.channels[channel, default: .init()].cancelListening(id: id)
56+
return self.channels[channel]?.cancelListening(id: id) ?? .none
5057
}
5158

5259
mutating func fail(_ error: Error) -> [NotificationListener] {

Sources/PostgresNIO/New/PSQLError.swift

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ public struct PSQLError: Error {
1818

1919
case queryCancelled
2020
case tooManyParameters
21-
case connectionQuiescing
22-
case connectionClosed
21+
case clientClosesConnection
22+
case clientClosedConnection
23+
case serverClosedConnection
2324
case connectionError
2425
case uncleanShutdown
2526

@@ -45,13 +46,20 @@ public struct PSQLError: Error {
4546
public static let invalidCommandTag = Self(.invalidCommandTag)
4647
public static let queryCancelled = Self(.queryCancelled)
4748
public static let tooManyParameters = Self(.tooManyParameters)
48-
public static let connectionQuiescing = Self(.connectionQuiescing)
49-
public static let connectionClosed = Self(.connectionClosed)
49+
public static let clientClosesConnection = Self(.clientClosesConnection)
50+
public static let clientClosedConnection = Self(.clientClosedConnection)
51+
public static let serverClosedConnection = Self(.serverClosedConnection)
5052
public static let connectionError = Self(.connectionError)
5153
public static let uncleanShutdown = Self.init(.uncleanShutdown)
5254
public static let listenFailed = Self.init(.listenFailed)
5355
public static let unlistenFailed = Self.init(.unlistenFailed)
5456

57+
@available(*, deprecated, renamed: "clientClosesConnection")
58+
public static let connectionQuiescing = Self.clientClosesConnection
59+
60+
@available(*, deprecated, message: "Use the more specific `serverClosedConnection` or `clientClosedConnection` instead")
61+
public static let connectionClosed = Self.serverClosedConnection
62+
5563
public var description: String {
5664
switch self.base {
5765
case .sslUnsupported:
@@ -78,10 +86,12 @@ public struct PSQLError: Error {
7886
return "queryCancelled"
7987
case .tooManyParameters:
8088
return "tooManyParameters"
81-
case .connectionQuiescing:
82-
return "connectionQuiescing"
83-
case .connectionClosed:
84-
return "connectionClosed"
89+
case .clientClosesConnection:
90+
return "clientClosesConnection"
91+
case .clientClosedConnection:
92+
return "clientClosedConnection"
93+
case .serverClosedConnection:
94+
return "serverClosedConnection"
8595
case .connectionError:
8696
return "connectionError"
8797
case .uncleanShutdown:
@@ -377,19 +387,33 @@ public struct PSQLError: Error {
377387
return new
378388
}
379389

380-
static var connectionQuiescing: PSQLError { PSQLError(code: .connectionQuiescing) }
390+
static func clientClosesConnection(underlying: Error?) -> PSQLError {
391+
var error = PSQLError(code: .clientClosesConnection)
392+
error.underlying = underlying
393+
return error
394+
}
395+
396+
static func clientClosedConnection(underlying: Error?) -> PSQLError {
397+
var error = PSQLError(code: .clientClosedConnection)
398+
error.underlying = underlying
399+
return error
400+
}
381401

382-
static var connectionClosed: PSQLError { PSQLError(code: .connectionClosed) }
402+
static func serverClosedConnection(underlying: Error?) -> PSQLError {
403+
var error = PSQLError(code: .serverClosedConnection)
404+
error.underlying = underlying
405+
return error
406+
}
383407

384-
static var authMechanismRequiresPassword: PSQLError { PSQLError(code: .authMechanismRequiresPassword) }
408+
static let authMechanismRequiresPassword = PSQLError(code: .authMechanismRequiresPassword)
385409

386-
static var sslUnsupported: PSQLError { PSQLError(code: .sslUnsupported) }
410+
static let sslUnsupported = PSQLError(code: .sslUnsupported)
387411

388-
static var queryCancelled: PSQLError { PSQLError(code: .queryCancelled) }
412+
static let queryCancelled = PSQLError(code: .queryCancelled)
389413

390-
static var uncleanShutdown: PSQLError { PSQLError(code: .uncleanShutdown) }
414+
static let uncleanShutdown = PSQLError(code: .uncleanShutdown)
391415

392-
static var receivedUnencryptedDataAfterSSLRequest: PSQLError { PSQLError(code: .receivedUnencryptedDataAfterSSLRequest) }
416+
static let receivedUnencryptedDataAfterSSLRequest = PSQLError(code: .receivedUnencryptedDataAfterSSLRequest)
393417

394418
static func server(_ response: PostgresBackendMessage.ErrorResponse) -> PSQLError {
395419
var error = PSQLError(code: .server)

Sources/PostgresNIO/New/PSQLEventsHandler.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ enum PSQLOutgoingEvent {
77
///
88
/// this shall be removed with the next breaking change and always supplied with `PSQLConnection.Configuration`
99
case authenticate(AuthContext)
10+
11+
case gracefulShutdown
1012
}
1113

1214
enum PSQLEvent {

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
247247
return
248248
}
249249

250-
let action = self.state.close(promise)
250+
let action = self.state.close(promise: promise)
251251
self.run(action, with: context)
252252
}
253253

@@ -258,6 +258,11 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
258258
case PSQLOutgoingEvent.authenticate(let authContext):
259259
let action = self.state.provideAuthenticationContext(authContext)
260260
self.run(action, with: context)
261+
262+
case PSQLOutgoingEvent.gracefulShutdown:
263+
let action = self.state.gracefulClose(promise)
264+
self.run(action, with: context)
265+
261266
default:
262267
context.triggerUserOutboundEvent(event, promise: promise)
263268
}

Sources/PostgresNIO/Postgres+PSQLCompat.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ extension PSQLError {
3737
return self.underlying ?? self
3838
case .tooManyParameters, .invalidCommandTag:
3939
return self
40-
case .connectionQuiescing:
41-
return PostgresError.connectionClosed
42-
case .connectionClosed:
40+
case .clientClosesConnection,
41+
.clientClosedConnection,
42+
.serverClosedConnection:
4343
return PostgresError.connectionClosed
4444
case .connectionError:
4545
return self.underlying ?? self

Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,14 @@ class ConnectionStateMachineTests: XCTestCase {
137137

138138
func testErrorIsIgnoredWhenClosingConnection() {
139139
// test ignore unclean shutdown when closing connection
140-
var stateIgnoreChannelError = ConnectionStateMachine(.closing)
141-
140+
var stateIgnoreChannelError = ConnectionStateMachine(.closing(nil))
141+
142142
XCTAssertEqual(stateIgnoreChannelError.errorHappened(.connectionError(underlying: NIOSSLError.uncleanShutdown)), .wait)
143143
XCTAssertEqual(stateIgnoreChannelError.closed(), .fireChannelInactive)
144144

145145
// test ignore any other error when closing connection
146146

147-
var stateIgnoreErrorMessage = ConnectionStateMachine(.closing)
147+
var stateIgnoreErrorMessage = ConnectionStateMachine(.closing(nil))
148148
XCTAssertEqual(stateIgnoreErrorMessage.errorReceived(.init(fields: [:])), .wait)
149149
XCTAssertEqual(stateIgnoreErrorMessage.closed(), .fireChannelInactive)
150150
}

Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ final class PSQLRowStreamTests: XCTestCase {
2222

2323
func testFailedStream() {
2424
let stream = PSQLRowStream(
25-
source: .noRows(.failure(PSQLError.connectionClosed)),
25+
source: .noRows(.failure(PSQLError.serverClosedConnection(underlying: nil))),
2626
eventLoop: self.eventLoop,
2727
logger: self.logger
2828
)
2929

3030
XCTAssertThrowsError(try stream.all().wait()) {
31-
XCTAssertEqual($0 as? PSQLError, .connectionClosed)
31+
XCTAssertEqual($0 as? PSQLError, .serverClosedConnection(underlying: nil))
3232
}
3333
}
3434

Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,11 @@ class PostgresChannelHandlerTests: XCTestCase {
2424
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
2525
handler
2626
], loop: self.eventLoop)
27-
defer { XCTAssertNoThrow(try embedded.finish()) }
28-
27+
defer {
28+
do { try embedded.finish() }
29+
catch { print("\(String(reflecting: error))") }
30+
}
31+
2932
var maybeMessage: PostgresFrontendMessage?
3033
XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil))
3134
XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self))

Tests/PostgresNIOTests/New/PostgresConnectionTests.swift

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,98 @@ class PostgresConnectionTests: XCTestCase {
182182
}
183183
}
184184

185+
func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws {
186+
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
187+
try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
188+
for _ in 1...2 {
189+
taskGroup.addTask {
190+
let rows = try await connection.query("SELECT 1;", logger: self.logger)
191+
var iterator = rows.decode(Int.self).makeAsyncIterator()
192+
let first = try await iterator.next()
193+
XCTAssertEqual(first, 1)
194+
let second = try await iterator.next()
195+
XCTAssertNil(second)
196+
}
197+
}
198+
199+
for i in 0...1 {
200+
let listenMessage = try await channel.waitForUnpreparedRequest()
201+
XCTAssertEqual(listenMessage.parse.query, "SELECT 1;")
202+
203+
if i == 0 {
204+
taskGroup.addTask {
205+
try await connection.closeGracefully()
206+
}
207+
}
208+
209+
try await channel.writeInbound(PostgresBackendMessage.parseComplete)
210+
try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: [])))
211+
let intDescription = RowDescription.Column(
212+
name: "",
213+
tableOID: 0,
214+
columnAttributeNumber: 0,
215+
dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary
216+
)
217+
try await channel.writeInbound(PostgresBackendMessage.rowDescription(.init(columns: [intDescription])))
218+
try await channel.testingEventLoop.executeInContext { channel.read() }
219+
try await channel.writeInbound(PostgresBackendMessage.bindComplete)
220+
try await channel.testingEventLoop.executeInContext { channel.read() }
221+
try await channel.writeInbound(PostgresBackendMessage.dataRow([Int(1)]))
222+
try await channel.testingEventLoop.executeInContext { channel.read() }
223+
try await channel.writeInbound(PostgresBackendMessage.commandComplete("SELECT 1 1"))
224+
try await channel.testingEventLoop.executeInContext { channel.read() }
225+
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))
226+
}
227+
228+
let terminate = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self)
229+
XCTAssertEqual(terminate, .terminate)
230+
try await channel.closeFuture.get()
231+
XCTAssertEqual(channel.isActive, false)
232+
233+
while let taskResult = await taskGroup.nextResult() {
234+
switch taskResult {
235+
case .success:
236+
break
237+
case .failure(let failure):
238+
XCTFail("Unexpected error: \(failure)")
239+
}
240+
}
241+
}
242+
}
243+
244+
func testCloseClosesImmediatly() async throws {
245+
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
246+
247+
try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
248+
for _ in 1...2 {
249+
taskGroup.addTask {
250+
try await connection.query("SELECT 1;", logger: self.logger)
251+
}
252+
}
253+
254+
let listenMessage = try await channel.waitForUnpreparedRequest()
255+
XCTAssertEqual(listenMessage.parse.query, "SELECT 1;")
256+
257+
async let close: () = connection.close()
258+
259+
try await channel.closeFuture.get()
260+
XCTAssertEqual(channel.isActive, false)
261+
262+
try await close
263+
264+
while let taskResult = await taskGroup.nextResult() {
265+
switch taskResult {
266+
case .success:
267+
XCTFail("Expected queries to fail")
268+
case .failure(let failure):
269+
guard let error = failure as? PSQLError else {
270+
return XCTFail("Unexpected error type: \(failure)")
271+
}
272+
XCTAssertEqual(error.code, .clientClosedConnection)
273+
}
274+
}
275+
}
276+
}
185277

186278
func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) {
187279
let eventLoop = NIOAsyncTestingEventLoop()

Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ final class PostgresRowSequenceTests: XCTestCase {
183183
logger: self.logger
184184
)
185185

186-
stream.receive(completion: .failure(PSQLError.connectionClosed))
186+
stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil)))
187187

188188
let rowSequence = stream.asyncSequence()
189189

@@ -194,7 +194,7 @@ final class PostgresRowSequenceTests: XCTestCase {
194194
}
195195
XCTFail("Expected that an error was thrown before.")
196196
} catch {
197-
XCTAssertEqual(error as? PSQLError, .connectionClosed)
197+
XCTAssertEqual(error as? PSQLError, .serverClosedConnection(underlying: nil))
198198
}
199199
}
200200

@@ -255,14 +255,14 @@ final class PostgresRowSequenceTests: XCTestCase {
255255
XCTAssertEqual(try row1?.decode(Int.self, context: .default), 0)
256256

257257
DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) {
258-
stream.receive(completion: .failure(PSQLError.connectionClosed))
258+
stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil)))
259259
}
260260

261261
do {
262262
_ = try await rowIterator.next()
263263
XCTFail("Expected that an error was thrown before.")
264264
} catch {
265-
XCTAssertEqual(error as? PSQLError, .connectionClosed)
265+
XCTAssertEqual(error as? PSQLError, .serverClosedConnection(underlying: nil))
266266
}
267267
}
268268

0 commit comments

Comments
 (0)