Skip to content

Commit ce06b45

Browse files
committed
Fixes bi-directional streaming
Motivation: When we stream request body, current implementation expects that body will finish streaming _before_ we start to receive response body parts. This is not correct, reponse body parts can start to arrive before we finish sending the request. Modifications: - Simplifies state machine, we only case about request being fully sent to prevent sending body parts after .end, but response state machine is mostly ignored and correct flow will be handled by NIOHTTP1 pipeline - Adds HTTPEchoHandler, that replies to each response body part - Adds bi-directional streaming test Result: Closes #327
1 parent ba845ee commit ce06b45

File tree

5 files changed

+95
-34
lines changed

5 files changed

+95
-34
lines changed

Sources/AsyncHTTPClient/HTTPHandler.swift

+23-17
Original file line numberDiff line numberDiff line change
@@ -664,12 +664,12 @@ internal struct TaskCancelEvent {}
664664

665665
internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChannelHandler {
666666
enum State {
667-
case idle
668-
case bodySent
669-
case sent
670-
case head
667+
case inactive
668+
case sending_waiting
669+
case sending_head
670+
case sent_waiting
671+
case sent_head
671672
case redirected(HTTPResponseHead, URL)
672-
case body
673673
case endOrError
674674
}
675675

@@ -679,7 +679,7 @@ internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChann
679679
let ignoreUncleanSSLShutdown: Bool
680680
let logger: Logger // We are okay to store the logger here because a TaskHandler is just for one request.
681681

682-
var state: State = .idle
682+
var state: State = .inactive
683683
var expectedBodyLength: Int?
684684
var actualBodyLength: Int = 0
685685
var pendingRead = false
@@ -794,7 +794,8 @@ extension TaskHandler: ChannelDuplexHandler {
794794
typealias OutboundOut = HTTPClientRequestPart
795795

796796
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
797-
self.state = .idle
797+
self.state = .sending_waiting
798+
798799
let request = self.unwrapOutboundIn(data)
799800

800801
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1),
@@ -840,23 +841,27 @@ extension TaskHandler: ChannelDuplexHandler {
840841
self.writeBody(request: request, context: context)
841842
}.flatMap {
842843
context.eventLoop.assertInEventLoop()
844+
843845
if case .endOrError = self.state {
844846
return context.eventLoop.makeSucceededFuture(())
847+
} else if case .sending_waiting = self.state {
848+
self.state = .sent_waiting
849+
} else if case .sending_head = self.state {
850+
self.state = .sent_head
845851
}
846852

847-
self.state = .bodySent
848853
if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength {
849854
let error = HTTPClientError.bodyLengthMismatch
850855
return context.eventLoop.makeFailedFuture(error)
851856
}
852857
return context.writeAndFlush(self.wrapOutboundOut(.end(nil)))
853858
}.map {
854859
context.eventLoop.assertInEventLoop()
860+
855861
if case .endOrError = self.state {
856862
return
857863
}
858864

859-
self.state = .sent
860865
self.callOutToDelegateFireAndForget(self.delegate.didSendRequest)
861866
}.flatMapErrorThrowing { error in
862867
context.eventLoop.assertInEventLoop()
@@ -902,7 +907,7 @@ extension TaskHandler: ChannelDuplexHandler {
902907

903908
private func writeBodyPart(context: ChannelHandlerContext, part: IOData, promise: EventLoopPromise<Void>) {
904909
switch self.state {
905-
case .idle:
910+
case .sending_waiting, .sending_head:
906911
if let limit = self.expectedBodyLength, self.actualBodyLength + part.readableBytes > limit {
907912
let error = HTTPClientError.bodyLengthMismatch
908913
self.errorCaught(context: context, error: error)
@@ -933,6 +938,10 @@ extension TaskHandler: ChannelDuplexHandler {
933938
case .head(let head):
934939
if case .endOrError = self.state {
935940
return
941+
} else if case .sending_waiting = self.state {
942+
self.state = .sending_head
943+
} else if case .sent_waiting = self.state {
944+
self.state = .sent_head
936945
}
937946

938947
if !head.isKeepAlive {
@@ -942,7 +951,6 @@ extension TaskHandler: ChannelDuplexHandler {
942951
if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) {
943952
self.state = .redirected(head, redirectURL)
944953
} else {
945-
self.state = .head
946954
self.mayRead = false
947955
self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead)
948956
.whenComplete { result in
@@ -954,7 +962,6 @@ extension TaskHandler: ChannelDuplexHandler {
954962
case .redirected, .endOrError:
955963
break
956964
default:
957-
self.state = .body
958965
self.mayRead = false
959966
self.callOutToDelegate(value: body, channelEventLoop: context.eventLoop, self.delegate.didReceiveBodyPart)
960967
.whenComplete { result in
@@ -1009,10 +1016,10 @@ extension TaskHandler: ChannelDuplexHandler {
10091016

10101017
func channelInactive(context: ChannelHandlerContext) {
10111018
switch self.state {
1019+
case .inactive, .sending_waiting, .sending_head, .sent_waiting, .sent_head, .redirected:
1020+
self.errorCaught(context: context, error: HTTPClientError.remoteConnectionClosed)
10121021
case .endOrError:
10131022
break
1014-
case .body, .head, .idle, .redirected, .sent, .bodySent:
1015-
self.errorCaught(context: context, error: HTTPClientError.remoteConnectionClosed)
10161023
}
10171024
context.fireChannelInactive()
10181025
}
@@ -1025,8 +1032,7 @@ extension TaskHandler: ChannelDuplexHandler {
10251032
/// Some HTTP Servers can 'forget' to respond with CloseNotify when client is closing connection,
10261033
/// this could lead to incomplete SSL shutdown. But since request is already processed, we can ignore this error.
10271034
break
1028-
case .head where self.ignoreUncleanSSLShutdown,
1029-
.body where self.ignoreUncleanSSLShutdown:
1035+
case .sent_head where self.ignoreUncleanSSLShutdown:
10301036
/// We can also ignore this error like `.end`.
10311037
break
10321038
default:
@@ -1035,7 +1041,7 @@ extension TaskHandler: ChannelDuplexHandler {
10351041
}
10361042
default:
10371043
switch self.state {
1038-
case .idle, .bodySent, .sent, .head, .redirected, .body:
1044+
case .inactive, .sending_waiting, .sending_head, .sent_waiting, .sent_head, .redirected:
10391045
self.state = .endOrError
10401046
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
10411047
case .endOrError:

Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class HTTPClientInternalTests: XCTestCase {
145145

146146
try channel.pipeline.addHandler(handler).wait()
147147

148-
handler.state = .sent
148+
handler.state = .sent_waiting
149149
var body = channel.allocator.buffer(capacity: 4)
150150
body.writeStaticString("1234")
151151

Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift

+24
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,30 @@ struct CollectEverythingLogHandler: LogHandler {
896896
}
897897
}
898898

899+
class HTTPEchoHandler: ChannelInboundHandler {
900+
typealias InboundIn = HTTPServerRequestPart
901+
typealias OutboundOut = HTTPServerResponsePart
902+
903+
var promises: CircularBuffer<EventLoopPromise<Void>> = CircularBuffer()
904+
905+
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
906+
let request = self.unwrapInboundIn(data)
907+
switch request {
908+
case .head:
909+
context.writeAndFlush(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), promise: nil)
910+
case .body(let bytes):
911+
context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(bytes)))).whenSuccess {
912+
if let promise = self.promises.popFirst() {
913+
promise.succeed(())
914+
}
915+
}
916+
case .end:
917+
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
918+
context.close(promise: nil)
919+
}
920+
}
921+
}
922+
899923
private let cert = """
900924
-----BEGIN CERTIFICATE-----
901925
MIICmDCCAYACCQCPC8JDqMh1zzANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJ1

Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift

+1
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ extension HTTPClientTests {
128128
("testSSLHandshakeErrorPropagation", testSSLHandshakeErrorPropagation),
129129
("testSSLHandshakeErrorPropagationDelayedClose", testSSLHandshakeErrorPropagationDelayedClose),
130130
("testWeCloseConnectionsWhenConnectionCloseSetByServer", testWeCloseConnectionsWhenConnectionCloseSetByServer),
131+
("testBiDirectionalStreaming", testBiDirectionalStreaming),
131132
]
132133
}
133134
}

Tests/AsyncHTTPClientTests/HTTPClientTests.swift

+46-16
Original file line numberDiff line numberDiff line change
@@ -2708,15 +2708,12 @@ class HTTPClientTests: XCTestCase {
27082708
let task = client.execute(request: request, delegate: TestHTTPDelegate())
27092709

27102710
XCTAssertThrowsError(try task.wait()) { error in
2711-
#if os(Linux)
2711+
if isTestingNIOTS() {
2712+
XCTAssertEqual(error as? ChannelError, .connectTimeout(.milliseconds(100)))
2713+
} else {
2714+
//XCTAssertEqual((error as? IOError).map { $0.errnoCode }, ECONNRESET)
27122715
XCTAssertEqual(error as? NIOSSLError, NIOSSLError.uncleanShutdown)
2713-
#else
2714-
if isTestingNIOTS() {
2715-
XCTAssertEqual(error as? ChannelError, .connectTimeout(.milliseconds(100)))
2716-
} else {
2717-
XCTAssertEqual((error as? IOError).map { $0.errnoCode }, ECONNRESET)
2718-
}
2719-
#endif
2716+
}
27202717
}
27212718
}
27222719

@@ -2756,15 +2753,12 @@ class HTTPClientTests: XCTestCase {
27562753
let task = client.execute(request: request, delegate: TestHTTPDelegate())
27572754

27582755
XCTAssertThrowsError(try task.wait()) { error in
2759-
#if os(Linux)
2756+
if isTestingNIOTS() {
2757+
XCTAssertEqual(error as? ChannelError, .connectTimeout(.milliseconds(200)))
2758+
} else {
2759+
//XCTAssertEqual((error as? IOError).map { $0.errnoCode }, ECONNRESET)
27602760
XCTAssertEqual(error as? NIOSSLError, NIOSSLError.uncleanShutdown)
2761-
#else
2762-
if isTestingNIOTS() {
2763-
XCTAssertEqual(error as? ChannelError, .connectTimeout(.milliseconds(200)))
2764-
} else {
2765-
XCTAssertEqual((error as? IOError).map { $0.errnoCode }, ECONNRESET)
2766-
}
2767-
#endif
2761+
}
27682762
}
27692763
}
27702764

@@ -2793,4 +2787,40 @@ class HTTPClientTests: XCTestCase {
27932787
let result = group.wait(timeout: DispatchTime.now() + DispatchTimeInterval.milliseconds(100))
27942788
XCTAssertEqual(result, .success, "we never closed the connection!")
27952789
}
2790+
2791+
func testBiDirectionalStreaming() throws {
2792+
let handler = HTTPEchoHandler()
2793+
2794+
let server = try ServerBootstrap(group: self.serverGroup)
2795+
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
2796+
.childChannelInitializer { channel in
2797+
channel.pipeline.configureHTTPServerPipeline().flatMap {
2798+
channel.pipeline.addHandler(handler)
2799+
}
2800+
}
2801+
.bind(host: "localhost", port: 0)
2802+
.wait()
2803+
2804+
defer {
2805+
server.close(promise: nil)
2806+
}
2807+
2808+
let body: HTTPClient.Body = .stream { writer in
2809+
let promise = self.clientGroup.next().makePromise(of: Void.self)
2810+
handler.promises.append(promise)
2811+
return writer.write(.byteBuffer(ByteBuffer(string: "hello"))).flatMap {
2812+
promise.futureResult
2813+
}.flatMap {
2814+
let promise = self.clientGroup.next().makePromise(of: Void.self)
2815+
handler.promises.append(promise)
2816+
return writer.write(.byteBuffer(ByteBuffer(string: "hello2"))).flatMap {
2817+
promise.futureResult
2818+
}
2819+
}
2820+
}
2821+
2822+
let future = self.defaultClient.execute(url: "http://localhost:\(server.localAddress!.port!)", body: body)
2823+
2824+
XCTAssertNoThrow(try future.wait())
2825+
}
27962826
}

0 commit comments

Comments
 (0)