Skip to content

Commit d2d3460

Browse files
async/await prepared statements
1 parent 52d5636 commit d2d3460

File tree

6 files changed

+290
-0
lines changed

6 files changed

+290
-0
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,27 @@ extension PostgresConnection {
460460
self.channel.write(task, promise: nil)
461461
}
462462
}
463+
464+
/// Execute a prepared statement, taking care of the preparation when necessary
465+
public func execute<P: PreparedStatement>(
466+
_ preparedStatement: P,
467+
logger: Logger
468+
) async throws -> AsyncThrowingMapSequence<PostgresRowSequence, P.Row>
469+
{
470+
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
471+
let task = HandlerTask.executePreparedStatement(.init(
472+
name: String(reflecting: P.self),
473+
sql: P.sql,
474+
bindings: preparedStatement.makeBindings(),
475+
logger: logger,
476+
promise: promise
477+
))
478+
self.channel.write(task, promise: nil)
479+
return try await promise.futureResult
480+
.map { $0.asyncSequence() }
481+
.get()
482+
.map { try preparedStatement.decodeRow($0) }
483+
}
463484
}
464485

465486
// MARK: EventLoopFuture interface
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import NIOCore
2+
3+
struct PreparedStatementStateMachine {
4+
enum State {
5+
case preparing([PreparedStatementContext])
6+
case prepared(RowDescription?)
7+
case error(PSQLError)
8+
}
9+
10+
enum Action {
11+
case prepareStatement
12+
case waitForAlreadyInFlightPreparation
13+
case executePendingStatements([PreparedStatementContext], RowDescription?)
14+
case returnError([PreparedStatementContext], PSQLError)
15+
}
16+
17+
var preparedStatements: [String: State]
18+
19+
init() {
20+
self.preparedStatements = [:]
21+
}
22+
23+
mutating func lookup(name: String, context: PreparedStatementContext) -> Action {
24+
if let state = self.preparedStatements[name] {
25+
switch state {
26+
case .preparing(var statements):
27+
statements.append(context)
28+
self.preparedStatements[name] = .preparing(statements)
29+
return .waitForAlreadyInFlightPreparation
30+
case .prepared(let rowDescription):
31+
return .executePendingStatements([context], rowDescription)
32+
case .error(let error):
33+
return .returnError([context], error)
34+
}
35+
} else {
36+
self.preparedStatements[name] = .preparing([context])
37+
return .prepareStatement
38+
}
39+
}
40+
41+
mutating func preparationComplete(
42+
name: String,
43+
rowDescription: RowDescription?
44+
) -> Action {
45+
guard case .preparing(let statements) = self.preparedStatements[name] else {
46+
preconditionFailure("Preparation completed for an unexpected statement")
47+
}
48+
// When sending the bindings we are going to ask for binary data.
49+
if var rowDescription {
50+
for i in 0..<rowDescription.columns.count {
51+
rowDescription.columns[i].format = .binary
52+
}
53+
self.preparedStatements[name] = .prepared(rowDescription)
54+
return .executePendingStatements(statements, rowDescription)
55+
} else {
56+
self.preparedStatements[name] = .prepared(nil)
57+
return .executePendingStatements(statements, nil)
58+
}
59+
}
60+
61+
mutating func errorHappened(name: String, error: PSQLError) -> Action {
62+
guard case .preparing(let statements) = self.preparedStatements[name] else {
63+
preconditionFailure("Preparation completed for an unexpected statement")
64+
}
65+
self.preparedStatements[name] = .error(error)
66+
return .returnError(statements, error)
67+
}
68+
}

Sources/PostgresNIO/New/PSQLTask.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ enum HandlerTask {
66
case closeCommand(CloseCommandContext)
77
case startListening(NotificationListener)
88
case cancelListening(String, Int)
9+
case executePreparedStatement(PreparedStatementContext)
910
}
1011

1112
enum PSQLTask {
@@ -69,6 +70,28 @@ final class ExtendedQueryContext {
6970
}
7071
}
7172

73+
final class PreparedStatementContext{
74+
let name: String
75+
let sql: String
76+
let bindings: PostgresBindings
77+
let logger: Logger
78+
let promise: EventLoopPromise<PSQLRowStream>
79+
80+
init(
81+
name: String,
82+
sql: String,
83+
bindings: PostgresBindings,
84+
logger: Logger,
85+
promise: EventLoopPromise<PSQLRowStream>
86+
) {
87+
self.name = name
88+
self.sql = sql
89+
self.bindings = bindings
90+
self.logger = logger
91+
self.promise = promise
92+
}
93+
}
94+
7295
final class CloseCommandContext {
7396
let target: CloseTarget
7497
let logger: Logger

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
2323
private let configureSSLCallback: ((Channel) throws -> Void)?
2424

2525
private var listenState: ListenStateMachine
26+
private var preparedStatementState: PreparedStatementStateMachine
2627

2728
init(
2829
configuration: PostgresConnection.InternalConfiguration,
@@ -33,6 +34,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
3334
self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData)
3435
self.eventLoop = eventLoop
3536
self.listenState = ListenStateMachine()
37+
self.preparedStatementState = PreparedStatementStateMachine()
3638
self.configuration = configuration
3739
self.configureSSLCallback = configureSSLCallback
3840
self.logger = logger
@@ -51,6 +53,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
5153
self.state = state
5254
self.eventLoop = eventLoop
5355
self.listenState = ListenStateMachine()
56+
self.preparedStatementState = PreparedStatementStateMachine()
5457
self.configuration = configuration
5558
self.configureSSLCallback = configureSSLCallback
5659
self.logger = logger
@@ -233,6 +236,56 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
233236
listener.failed(CancellationError())
234237
return
235238
}
239+
case .executePreparedStatement(let preparedStatement):
240+
switch self.preparedStatementState.lookup(
241+
name: preparedStatement.name,
242+
context: preparedStatement
243+
) {
244+
case .prepareStatement:
245+
let promise = self.eventLoop.makePromise(of: RowDescription?.self)
246+
promise.futureResult.whenSuccess { rowDescription in
247+
self.prepareStatementComplete(
248+
name: preparedStatement.name,
249+
rowDescription: rowDescription,
250+
context: context
251+
)
252+
}
253+
promise.futureResult.whenFailure { error in
254+
self.prepareStatementFailed(
255+
name: preparedStatement.name,
256+
error: error as! PSQLError,
257+
context: context
258+
)
259+
}
260+
psqlTask = .extendedQuery(.init(
261+
name: preparedStatement.name,
262+
query: preparedStatement.sql,
263+
logger: preparedStatement.logger,
264+
promise: promise
265+
))
266+
case .waitForAlreadyInFlightPreparation:
267+
// The state machine already keeps track of this
268+
// and will execute the statement as soon as it's prepared
269+
return
270+
case .executePendingStatements(let pendingStatements, let rowDescription):
271+
for statement in pendingStatements {
272+
let action = self.state.enqueue(task: .extendedQuery(.init(
273+
executeStatement: .init(
274+
name: statement.name,
275+
binds: statement.bindings,
276+
rowDescription: rowDescription),
277+
logger: statement.logger,
278+
promise: statement.promise
279+
)))
280+
self.run(action, with: context)
281+
}
282+
return
283+
case .returnError(let pendingStatements, let error):
284+
for statement in pendingStatements {
285+
statement.promise.fail(error)
286+
}
287+
return
288+
}
236289
}
237290

238291
let action = self.state.enqueue(task: psqlTask)
@@ -664,6 +717,49 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
664717
}
665718
}
666719

720+
private func prepareStatementComplete(
721+
name: String,
722+
rowDescription: RowDescription?,
723+
context: ChannelHandlerContext
724+
) {
725+
let action = self.preparedStatementState.preparationComplete(
726+
name: name,
727+
rowDescription: rowDescription
728+
)
729+
guard case .executePendingStatements(let statements, let rowDescription) = action else {
730+
preconditionFailure("Expected to have pending statements to execute")
731+
}
732+
for preparedStatement in statements {
733+
let action = self.state.enqueue(task: .extendedQuery(.init(
734+
executeStatement: .init(
735+
name: preparedStatement.name,
736+
binds: preparedStatement.bindings,
737+
rowDescription: rowDescription
738+
),
739+
logger: preparedStatement.logger,
740+
promise: preparedStatement.promise
741+
))
742+
)
743+
self.run(action, with: context)
744+
}
745+
}
746+
747+
private func prepareStatementFailed(
748+
name: String,
749+
error: PSQLError,
750+
context: ChannelHandlerContext
751+
) {
752+
let action = self.preparedStatementState.errorHappened(
753+
name: name,
754+
error: error
755+
)
756+
guard case .returnError(let statements, let error) = action else {
757+
preconditionFailure("Expected to have pending statements to execute")
758+
}
759+
for statement in statements {
760+
statement.promise.fail(error)
761+
}
762+
}
667763
}
668764

669765
extension PostgresChannelHandler: PSQLRowsDataSource {
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/// A prepared statement.
2+
///
3+
/// Structs conforming to this protocol will need to provide the SQL statement to
4+
/// send to the server and a way of creating bindings are decoding the result.
5+
///
6+
/// As an example, consider this struct:
7+
/// ```swift
8+
/// struct Example: PreparedStatement {
9+
/// static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1"
10+
/// typealias Row = (Int, String)
11+
///
12+
/// var state: String
13+
///
14+
/// func makeBindings() -> PostgresBindings {
15+
/// var bindings = PostgresBindings()
16+
/// bindings.append(.init(string: self.state))
17+
/// return bindings
18+
/// }
19+
///
20+
/// func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
21+
/// try row.decode(Row.self)
22+
/// }
23+
/// }
24+
/// ```
25+
///
26+
/// Structs conforming to this protocol can then be used with `PostgresConnection.execute(_ preparedStatement:, logger:)`,
27+
/// which will take care of preparing the statement on the server side and executing it.
28+
public protocol PreparedStatement {
29+
/// The type rows returned by the statement will be decoded into
30+
associatedtype Row
31+
32+
/// The SQL statement to prepare on the database server.
33+
static var sql: String { get }
34+
35+
/// Make the bindings to provided concrete values to use when executing the prepared SQL statement
36+
func makeBindings() -> PostgresBindings
37+
38+
/// Decode a row returned by the database into an instance of `Row`
39+
func decodeRow(_ row: PostgresRow) throws -> Row
40+
}

Tests/IntegrationTests/AsyncTests.swift

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,48 @@ final class AsyncPostgresConnectionTests: XCTestCase {
315315
try await connection.query("SELECT 1;", logger: .psqlTest)
316316
}
317317
}
318+
319+
func testPreparedStatement() async throws {
320+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
321+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
322+
let eventLoop = eventLoopGroup.next()
323+
324+
struct TestPreparedStatement: PreparedStatement {
325+
static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1"
326+
typealias Row = (Int, String)
327+
328+
var state: String
329+
330+
func makeBindings() -> PostgresBindings {
331+
var bindings = PostgresBindings()
332+
bindings.append(.init(string: self.state))
333+
return bindings
334+
}
335+
336+
func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
337+
try row.decode(Row.self)
338+
}
339+
}
340+
let preparedStatement = TestPreparedStatement(state: "active")
341+
try await withTestConnection(on: eventLoop) { connection in
342+
var results = try await connection.execute(preparedStatement, logger: .psqlTest)
343+
var counter = 0
344+
345+
for try await element in results {
346+
XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database")
347+
counter += 1
348+
}
349+
350+
XCTAssertGreaterThanOrEqual(counter, 1)
351+
352+
// Second execution, which reuses the existing prepared statement
353+
results = try await connection.execute(preparedStatement, logger: .psqlTest)
354+
for try await element in results {
355+
XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database")
356+
counter += 1
357+
}
358+
}
359+
}
318360
}
319361

320362
extension XCTestCase {

0 commit comments

Comments
 (0)