Skip to content

async/await prepared statement API #390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,72 @@ extension PostgresConnection {
self.channel.write(task, promise: nil)
}
}

/// Execute a prepared statement, taking care of the preparation when necessary
public func execute<Statement: PostgresPreparedStatement, Row>(
_ preparedStatement: Statement,
logger: Logger,
file: String = #fileID,
line: Int = #line
) async throws -> AsyncThrowingMapSequence<PostgresRowSequence, Row> where Row == Statement.Row {
let bindings = try preparedStatement.makeBindings()
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
let task = HandlerTask.executePreparedStatement(.init(
name: String(reflecting: Statement.self),
sql: Statement.sql,
bindings: bindings,
logger: logger,
promise: promise
))
self.channel.write(task, promise: nil)
do {
return try await promise.futureResult
.map { $0.asyncSequence() }
.get()
.map { try preparedStatement.decodeRow($0) }
} catch var error as PSQLError {
error.file = file
error.line = line
error.query = .init(
unsafeSQL: Statement.sql,
binds: bindings
)
throw error // rethrow with more metadata
}

}

/// Execute a prepared statement, taking care of the preparation when necessary
public func execute<Statement: PostgresPreparedStatement>(
_ preparedStatement: Statement,
logger: Logger,
file: String = #fileID,
line: Int = #line
) async throws -> String where Statement.Row == () {
let bindings = try preparedStatement.makeBindings()
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
let task = HandlerTask.executePreparedStatement(.init(
name: String(reflecting: Statement.self),
sql: Statement.sql,
bindings: bindings,
logger: logger,
promise: promise
))
self.channel.write(task, promise: nil)
do {
return try await promise.futureResult
.map { $0.commandTag }
.get()
} catch var error as PSQLError {
error.file = file
error.line = line
error.query = .init(
unsafeSQL: Statement.sql,
binds: bindings
)
throw error // rethrow with more metadata
}
}
}

// MARK: EventLoopFuture interface
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import NIOCore

struct PreparedStatementStateMachine {
enum State {
case preparing([PreparedStatementContext])
case prepared(RowDescription?)
case error(PSQLError)
}

var preparedStatements: [String: State] = [:]

enum LookupAction {
case prepareStatement
case waitForAlreadyInFlightPreparation
case executeStatement(RowDescription?)
case returnError(PSQLError)
}

mutating func lookup(preparedStatement: PreparedStatementContext) -> LookupAction {
if let state = self.preparedStatements[preparedStatement.name] {
switch state {
case .preparing(var statements):
statements.append(preparedStatement)
self.preparedStatements[preparedStatement.name] = .preparing(statements)
return .waitForAlreadyInFlightPreparation
case .prepared(let rowDescription):
return .executeStatement(rowDescription)
case .error(let error):
return .returnError(error)
}
} else {
self.preparedStatements[preparedStatement.name] = .preparing([preparedStatement])
return .prepareStatement
}
}

struct PreparationCompleteAction {
var statements: [PreparedStatementContext]
var rowDescription: RowDescription?
}

mutating func preparationComplete(
name: String,
rowDescription: RowDescription?
) -> PreparationCompleteAction {
guard let state = self.preparedStatements[name] else {
fatalError("Unknown prepared statement \(name)")
}
switch state {
case .preparing(let statements):
// When sending the bindings we are going to ask for binary data.
if var rowDescription = rowDescription {
for i in 0..<rowDescription.columns.count {
rowDescription.columns[i].format = .binary
}
self.preparedStatements[name] = .prepared(rowDescription)
return PreparationCompleteAction(
statements: statements,
rowDescription: rowDescription
)
} else {
self.preparedStatements[name] = .prepared(nil)
return PreparationCompleteAction(
statements: statements,
rowDescription: nil
)
}
case .prepared, .error:
preconditionFailure("Preparation completed happened in an unexpected state \(state)")
}
}

struct ErrorHappenedAction {
var statements: [PreparedStatementContext]
var error: PSQLError
}

mutating func errorHappened(name: String, error: PSQLError) -> ErrorHappenedAction {
guard let state = self.preparedStatements[name] else {
fatalError("Unknown prepared statement \(name)")
}
switch state {
case .preparing(let statements):
self.preparedStatements[name] = .error(error)
return ErrorHappenedAction(
statements: statements,
error: error
)
case .prepared, .error:
preconditionFailure("Error happened in an unexpected state \(state)")
}
}
}
23 changes: 23 additions & 0 deletions Sources/PostgresNIO/New/PSQLTask.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ enum HandlerTask {
case closeCommand(CloseCommandContext)
case startListening(NotificationListener)
case cancelListening(String, Int)
case executePreparedStatement(PreparedStatementContext)
}

enum PSQLTask {
Expand Down Expand Up @@ -69,6 +70,28 @@ final class ExtendedQueryContext {
}
}

final class PreparedStatementContext{
let name: String
let sql: String
let bindings: PostgresBindings
let logger: Logger
let promise: EventLoopPromise<PSQLRowStream>

init(
name: String,
sql: String,
bindings: PostgresBindings,
logger: Logger,
promise: EventLoopPromise<PSQLRowStream>
) {
self.name = name
self.sql = sql
self.bindings = bindings
self.logger = logger
self.promise = promise
}
}

final class CloseCommandContext {
let target: CloseTarget
let logger: Logger
Expand Down
115 changes: 112 additions & 3 deletions Sources/PostgresNIO/New/PostgresChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
private let configuration: PostgresConnection.InternalConfiguration
private let configureSSLCallback: ((Channel) throws -> Void)?

private var listenState: ListenStateMachine
private var listenState = ListenStateMachine()
private var preparedStatementState = PreparedStatementStateMachine()

init(
configuration: PostgresConnection.InternalConfiguration,
Expand All @@ -32,7 +33,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
) {
self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData)
self.eventLoop = eventLoop
self.listenState = ListenStateMachine()
self.configuration = configuration
self.configureSSLCallback = configureSSLCallback
self.logger = logger
Expand All @@ -50,7 +50,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
) {
self.state = state
self.eventLoop = eventLoop
self.listenState = ListenStateMachine()
self.configuration = configuration
self.configureSSLCallback = configureSSLCallback
self.logger = logger
Expand Down Expand Up @@ -233,6 +232,29 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
listener.failed(CancellationError())
return
}
case .executePreparedStatement(let preparedStatement):
let action = self.preparedStatementState.lookup(
preparedStatement: preparedStatement
)
switch action {
case .prepareStatement:
psqlTask = self.makePrepareStatementTask(
preparedStatement: preparedStatement,
context: context
)
case .waitForAlreadyInFlightPreparation:
// The state machine already keeps track of this
// and will execute the statement as soon as it's prepared
return
case .executeStatement(let rowDescription):
psqlTask = self.makeExecutePreparedStatementTask(
preparedStatement: preparedStatement,
rowDescription: rowDescription
)
case .returnError(let error):
preparedStatement.promise.fail(error)
return
}
}

let action = self.state.enqueue(task: psqlTask)
Expand Down Expand Up @@ -664,6 +686,93 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
}
}

private func makePrepareStatementTask(
preparedStatement: PreparedStatementContext,
context: ChannelHandlerContext
) -> PSQLTask {
let promise = self.eventLoop.makePromise(of: RowDescription?.self)
promise.futureResult.whenComplete { result in
switch result {
case .success(let rowDescription):
self.prepareStatementComplete(
name: preparedStatement.name,
rowDescription: rowDescription,
context: context
)
case .failure(let error):
let psqlError: PSQLError
if let error = error as? PSQLError {
psqlError = error
} else {
psqlError = .connectionError(underlying: error)
}
self.prepareStatementFailed(
name: preparedStatement.name,
error: psqlError,
context: context
)
}
}
return .extendedQuery(.init(
name: preparedStatement.name,
query: preparedStatement.sql,
logger: preparedStatement.logger,
promise: promise
))
}

private func makeExecutePreparedStatementTask(
preparedStatement: PreparedStatementContext,
rowDescription: RowDescription?
) -> PSQLTask {
return .extendedQuery(.init(
executeStatement: .init(
name: preparedStatement.name,
binds: preparedStatement.bindings,
rowDescription: rowDescription
),
logger: preparedStatement.logger,
promise: preparedStatement.promise
))
}

private func prepareStatementComplete(
name: String,
rowDescription: RowDescription?,
context: ChannelHandlerContext
) {
let action = self.preparedStatementState.preparationComplete(
name: name,
rowDescription: rowDescription
)
for preparedStatement in action.statements {
let action = self.state.enqueue(task: .extendedQuery(.init(
executeStatement: .init(
name: preparedStatement.name,
binds: preparedStatement.bindings,
rowDescription: action.rowDescription
),
logger: preparedStatement.logger,
promise: preparedStatement.promise
))
)
self.run(action, with: context)
}
}

private func prepareStatementFailed(
name: String,
error: PSQLError,
context: ChannelHandlerContext
) {
let action = self.preparedStatementState.errorHappened(
name: name,
error: error
)
for statement in action.statements {
statement.promise.fail(action.error)
}
}
}

extension PostgresChannelHandler: PSQLRowsDataSource {
Expand Down
10 changes: 10 additions & 0 deletions Sources/PostgresNIO/New/PostgresQuery.swift
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ public struct PostgresBindings: Sendable, Hashable {
self.metadata.append(.init(dataType: .null, format: .binary, protected: true))
}

@inlinable
public mutating func append<Value: PostgresEncodable>(_ value: Value) throws {
try self.append(value, context: .default)
}

@inlinable
public mutating func append<Value: PostgresEncodable, JSONEncoder: PostgresJSONEncoder>(
_ value: Value,
Expand All @@ -176,6 +181,11 @@ public struct PostgresBindings: Sendable, Hashable {
self.metadata.append(.init(value: value, protected: true))
}

@inlinable
public mutating func append<Value: PostgresNonThrowingEncodable>(_ value: Value) {
self.append(value, context: .default)
}

@inlinable
public mutating func append<Value: PostgresNonThrowingEncodable, JSONEncoder: PostgresJSONEncoder>(
_ value: Value,
Expand Down
Loading