Skip to content

Commit 4cf7f48

Browse files
Address PR feedbacks
1 parent d2d3460 commit 4cf7f48

File tree

5 files changed

+63
-23
lines changed

5 files changed

+63
-23
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -462,15 +462,15 @@ extension PostgresConnection {
462462
}
463463

464464
/// Execute a prepared statement, taking care of the preparation when necessary
465-
public func execute<P: PreparedStatement>(
466-
_ preparedStatement: P,
465+
public func execute<Statement: PostgresPreparedStatement>(
466+
_ preparedStatement: Statement,
467467
logger: Logger
468-
) async throws -> AsyncThrowingMapSequence<PostgresRowSequence, P.Row>
468+
) async throws -> AsyncThrowingMapSequence<PostgresRowSequence, Statement.Row>
469469
{
470470
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
471471
let task = HandlerTask.executePreparedStatement(.init(
472-
name: String(reflecting: P.self),
473-
sql: P.sql,
472+
name: String(reflecting: Statement.self),
473+
sql: Statement.sql,
474474
bindings: preparedStatement.makeBindings(),
475475
logger: logger,
476476
promise: promise
@@ -481,6 +481,27 @@ extension PostgresConnection {
481481
.get()
482482
.map { try preparedStatement.decodeRow($0) }
483483
}
484+
485+
/// Execute a prepared statement, taking care of the preparation when necessary
486+
public func execute<Statement: PostgresPreparedStatement>(
487+
_ preparedStatement: Statement,
488+
logger: Logger
489+
) async throws -> String
490+
where Statement.Row == ()
491+
{
492+
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
493+
let task = HandlerTask.executePreparedStatement(.init(
494+
name: String(reflecting: Statement.self),
495+
sql: Statement.sql,
496+
bindings: preparedStatement.makeBindings(),
497+
logger: logger,
498+
promise: promise
499+
))
500+
self.channel.write(task, promise: nil)
501+
return try await promise.futureResult
502+
.map { $0.commandTag }
503+
.get()
504+
}
484505
}
485506

486507
// MARK: EventLoopFuture interface

Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,30 @@ struct PreparedStatementStateMachine {
66
case prepared(RowDescription?)
77
case error(PSQLError)
88
}
9+
10+
var preparedStatements: [String: State]
911

10-
enum Action {
12+
init() {
13+
self.preparedStatements = [:]
14+
}
15+
16+
enum LookupAction {
1117
case prepareStatement
1218
case waitForAlreadyInFlightPreparation
19+
case executeStatement(RowDescription?)
1320
case executePendingStatements([PreparedStatementContext], RowDescription?)
1421
case returnError([PreparedStatementContext], PSQLError)
1522
}
16-
17-
var preparedStatements: [String: State]
18-
19-
init() {
20-
self.preparedStatements = [:]
21-
}
22-
23-
mutating func lookup(name: String, context: PreparedStatementContext) -> Action {
23+
24+
mutating func lookup(name: String, context: PreparedStatementContext) -> LookupAction {
2425
if let state = self.preparedStatements[name] {
2526
switch state {
2627
case .preparing(var statements):
2728
statements.append(context)
2829
self.preparedStatements[name] = .preparing(statements)
2930
return .waitForAlreadyInFlightPreparation
3031
case .prepared(let rowDescription):
31-
return .executePendingStatements([context], rowDescription)
32+
return .executeStatement(rowDescription)
3233
case .error(let error):
3334
return .returnError([context], error)
3435
}
@@ -37,11 +38,15 @@ struct PreparedStatementStateMachine {
3738
return .prepareStatement
3839
}
3940
}
40-
41+
42+
enum PreparationCompleteAction {
43+
case executePendingStatements([PreparedStatementContext], RowDescription?)
44+
}
45+
4146
mutating func preparationComplete(
4247
name: String,
4348
rowDescription: RowDescription?
44-
) -> Action {
49+
) -> PreparationCompleteAction {
4550
guard case .preparing(let statements) = self.preparedStatements[name] else {
4651
preconditionFailure("Preparation completed for an unexpected statement")
4752
}
@@ -58,7 +63,11 @@ struct PreparedStatementStateMachine {
5863
}
5964
}
6065

61-
mutating func errorHappened(name: String, error: PSQLError) -> Action {
66+
enum ErrorHappenedAction {
67+
case returnError([PreparedStatementContext], PSQLError)
68+
}
69+
70+
mutating func errorHappened(name: String, error: PSQLError) -> ErrorHappenedAction {
6271
guard case .preparing(let statements) = self.preparedStatements[name] else {
6372
preconditionFailure("Preparation completed for an unexpected statement")
6473
}

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,10 +237,11 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
237237
return
238238
}
239239
case .executePreparedStatement(let preparedStatement):
240-
switch self.preparedStatementState.lookup(
240+
let action = self.preparedStatementState.lookup(
241241
name: preparedStatement.name,
242242
context: preparedStatement
243-
) {
243+
)
244+
switch action {
244245
case .prepareStatement:
245246
let promise = self.eventLoop.makePromise(of: RowDescription?.self)
246247
promise.futureResult.whenSuccess { rowDescription in
@@ -267,6 +268,15 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
267268
// The state machine already keeps track of this
268269
// and will execute the statement as soon as it's prepared
269270
return
271+
case .executeStatement(let rowDescription):
272+
psqlTask = .extendedQuery(.init(
273+
executeStatement: .init(
274+
name: preparedStatement.name,
275+
binds: preparedStatement.bindings,
276+
rowDescription: rowDescription),
277+
logger: preparedStatement.logger,
278+
promise: preparedStatement.promise
279+
))
270280
case .executePendingStatements(let pendingStatements, let rowDescription):
271281
for statement in pendingStatements {
272282
let action = self.state.enqueue(task: .extendedQuery(.init(

Sources/PostgresNIO/New/PreparedStatement.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
///
66
/// As an example, consider this struct:
77
/// ```swift
8-
/// struct Example: PreparedStatement {
8+
/// struct Example: PostgresPreparedStatement {
99
/// static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1"
1010
/// typealias Row = (Int, String)
1111
///
@@ -25,7 +25,7 @@
2525
///
2626
/// Structs conforming to this protocol can then be used with `PostgresConnection.execute(_ preparedStatement:, logger:)`,
2727
/// which will take care of preparing the statement on the server side and executing it.
28-
public protocol PreparedStatement {
28+
public protocol PostgresPreparedStatement: Sendable {
2929
/// The type rows returned by the statement will be decoded into
3030
associatedtype Row
3131

Tests/IntegrationTests/AsyncTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ final class AsyncPostgresConnectionTests: XCTestCase {
321321
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
322322
let eventLoop = eventLoopGroup.next()
323323

324-
struct TestPreparedStatement: PreparedStatement {
324+
struct TestPreparedStatement: PostgresPreparedStatement {
325325
static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1"
326326
typealias Row = (Int, String)
327327

0 commit comments

Comments
 (0)