Skip to content

Commit 06a4098

Browse files
committed
Fixes bi-directional streaming
Motivation: When we stream request body, current implementation expects that body will finish streaming _before_ we start to receive response body parts. This is not correct, reponse body parts can start to arrive before we finish sending the request. Modifications: - Simplifies state machine, we only case about request being fully sent to prevent sending body parts after .end, but response state machine is mostly ignored and correct flow will be handled by NIOHTTP1 pipeline - Adds HTTPEchoHandler, that replies to each response body part - Adds bi-directional streaming test Result: Closes #327
1 parent ba845ee commit 06a4098

File tree

5 files changed

+106
-34
lines changed

5 files changed

+106
-34
lines changed

Sources/AsyncHTTPClient/HTTPHandler.swift

+36-17
Original file line numberDiff line numberDiff line change
@@ -665,11 +665,11 @@ internal struct TaskCancelEvent {}
665665
internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChannelHandler {
666666
enum State {
667667
case idle
668-
case bodySent
669-
case sent
670-
case head
668+
case sendingBodyWaitingResponseHead
669+
case sendingBodyResponseHeadReceived
670+
case bodySentWaitingResponseHead
671+
case bodySentResponseHeadReceived
671672
case redirected(HTTPResponseHead, URL)
672-
case body
673673
case endOrError
674674
}
675675

@@ -794,7 +794,8 @@ extension TaskHandler: ChannelDuplexHandler {
794794
typealias OutboundOut = HTTPClientRequestPart
795795

796796
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
797-
self.state = .idle
797+
self.state = .sendingBodyWaitingResponseHead
798+
798799
let request = self.unwrapOutboundIn(data)
799800

800801
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1),
@@ -840,23 +841,32 @@ extension TaskHandler: ChannelDuplexHandler {
840841
self.writeBody(request: request, context: context)
841842
}.flatMap {
842843
context.eventLoop.assertInEventLoop()
843-
if case .endOrError = self.state {
844+
845+
switch self.state {
846+
case .idle:
847+
preconditionFailure("should not happen")
848+
case .sendingBodyWaitingResponseHead:
849+
self.state = .bodySentWaitingResponseHead
850+
case .sendingBodyResponseHeadReceived:
851+
self.state = .bodySentResponseHeadReceived
852+
case .bodySentWaitingResponseHead, .bodySentResponseHeadReceived:
853+
preconditionFailure("should not happen")
854+
case .endOrError, .redirected:
844855
return context.eventLoop.makeSucceededFuture(())
845856
}
846857

847-
self.state = .bodySent
848858
if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength {
849859
let error = HTTPClientError.bodyLengthMismatch
850860
return context.eventLoop.makeFailedFuture(error)
851861
}
852862
return context.writeAndFlush(self.wrapOutboundOut(.end(nil)))
853863
}.map {
854864
context.eventLoop.assertInEventLoop()
865+
855866
if case .endOrError = self.state {
856867
return
857868
}
858869

859-
self.state = .sent
860870
self.callOutToDelegateFireAndForget(self.delegate.didSendRequest)
861871
}.flatMapErrorThrowing { error in
862872
context.eventLoop.assertInEventLoop()
@@ -903,6 +913,8 @@ extension TaskHandler: ChannelDuplexHandler {
903913
private func writeBodyPart(context: ChannelHandlerContext, part: IOData, promise: EventLoopPromise<Void>) {
904914
switch self.state {
905915
case .idle:
916+
preconditionFailure("should not happen")
917+
case .sendingBodyWaitingResponseHead, .sendingBodyResponseHeadReceived:
906918
if let limit = self.expectedBodyLength, self.actualBodyLength + part.readableBytes > limit {
907919
let error = HTTPClientError.bodyLengthMismatch
908920
self.errorCaught(context: context, error: error)
@@ -911,7 +923,7 @@ extension TaskHandler: ChannelDuplexHandler {
911923
}
912924
self.actualBodyLength += part.readableBytes
913925
context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise)
914-
default:
926+
case .bodySentWaitingResponseHead, .bodySentResponseHeadReceived, .redirected, .endOrError:
915927
let error = HTTPClientError.writeAfterRequestSent
916928
self.errorCaught(context: context, error: error)
917929
promise.fail(error)
@@ -931,7 +943,16 @@ extension TaskHandler: ChannelDuplexHandler {
931943
let response = self.unwrapInboundIn(data)
932944
switch response {
933945
case .head(let head):
934-
if case .endOrError = self.state {
946+
switch self.state {
947+
case .idle:
948+
preconditionFailure("should not happen")
949+
case .sendingBodyWaitingResponseHead:
950+
self.state = .sendingBodyResponseHeadReceived
951+
case .bodySentWaitingResponseHead:
952+
self.state = .bodySentResponseHeadReceived
953+
case .sendingBodyResponseHeadReceived, .bodySentResponseHeadReceived, .redirected:
954+
preconditionFailure("should not happen")
955+
case .endOrError:
935956
return
936957
}
937958

@@ -942,7 +963,6 @@ extension TaskHandler: ChannelDuplexHandler {
942963
if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) {
943964
self.state = .redirected(head, redirectURL)
944965
} else {
945-
self.state = .head
946966
self.mayRead = false
947967
self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead)
948968
.whenComplete { result in
@@ -954,7 +974,6 @@ extension TaskHandler: ChannelDuplexHandler {
954974
case .redirected, .endOrError:
955975
break
956976
default:
957-
self.state = .body
958977
self.mayRead = false
959978
self.callOutToDelegate(value: body, channelEventLoop: context.eventLoop, self.delegate.didReceiveBodyPart)
960979
.whenComplete { result in
@@ -1009,10 +1028,10 @@ extension TaskHandler: ChannelDuplexHandler {
10091028

10101029
func channelInactive(context: ChannelHandlerContext) {
10111030
switch self.state {
1031+
case .idle, .sendingBodyWaitingResponseHead, .sendingBodyResponseHeadReceived, .bodySentWaitingResponseHead, .bodySentResponseHeadReceived, .redirected:
1032+
self.errorCaught(context: context, error: HTTPClientError.remoteConnectionClosed)
10121033
case .endOrError:
10131034
break
1014-
case .body, .head, .idle, .redirected, .sent, .bodySent:
1015-
self.errorCaught(context: context, error: HTTPClientError.remoteConnectionClosed)
10161035
}
10171036
context.fireChannelInactive()
10181037
}
@@ -1025,8 +1044,8 @@ extension TaskHandler: ChannelDuplexHandler {
10251044
/// Some HTTP Servers can 'forget' to respond with CloseNotify when client is closing connection,
10261045
/// this could lead to incomplete SSL shutdown. But since request is already processed, we can ignore this error.
10271046
break
1028-
case .head where self.ignoreUncleanSSLShutdown,
1029-
.body where self.ignoreUncleanSSLShutdown:
1047+
case .sendingBodyResponseHeadReceived where self.ignoreUncleanSSLShutdown,
1048+
.bodySentResponseHeadReceived where self.ignoreUncleanSSLShutdown:
10301049
/// We can also ignore this error like `.end`.
10311050
break
10321051
default:
@@ -1035,7 +1054,7 @@ extension TaskHandler: ChannelDuplexHandler {
10351054
}
10361055
default:
10371056
switch self.state {
1038-
case .idle, .bodySent, .sent, .head, .redirected, .body:
1057+
case .idle, .sendingBodyWaitingResponseHead, .sendingBodyResponseHeadReceived, .bodySentWaitingResponseHead, .bodySentResponseHeadReceived, .redirected:
10391058
self.state = .endOrError
10401059
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
10411060
case .endOrError:

Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class HTTPClientInternalTests: XCTestCase {
145145

146146
try channel.pipeline.addHandler(handler).wait()
147147

148-
handler.state = .sent
148+
handler.state = .bodySentWaitingResponseHead
149149
var body = channel.allocator.buffer(capacity: 4)
150150
body.writeStaticString("1234")
151151

Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift

+24
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,30 @@ struct CollectEverythingLogHandler: LogHandler {
896896
}
897897
}
898898

899+
class HTTPEchoHandler: ChannelInboundHandler {
900+
typealias InboundIn = HTTPServerRequestPart
901+
typealias OutboundOut = HTTPServerResponsePart
902+
903+
var promises: CircularBuffer<EventLoopPromise<Void>> = CircularBuffer()
904+
905+
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
906+
let request = self.unwrapInboundIn(data)
907+
switch request {
908+
case .head:
909+
context.writeAndFlush(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), promise: nil)
910+
case .body(let bytes):
911+
context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(bytes)))).whenSuccess {
912+
if let promise = self.promises.popFirst() {
913+
promise.succeed(())
914+
}
915+
}
916+
case .end:
917+
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
918+
context.close(promise: nil)
919+
}
920+
}
921+
}
922+
899923
private let cert = """
900924
-----BEGIN CERTIFICATE-----
901925
MIICmDCCAYACCQCPC8JDqMh1zzANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJ1

Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift

+1
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ extension HTTPClientTests {
128128
("testSSLHandshakeErrorPropagation", testSSLHandshakeErrorPropagation),
129129
("testSSLHandshakeErrorPropagationDelayedClose", testSSLHandshakeErrorPropagationDelayedClose),
130130
("testWeCloseConnectionsWhenConnectionCloseSetByServer", testWeCloseConnectionsWhenConnectionCloseSetByServer),
131+
("testBiDirectionalStreaming", testBiDirectionalStreaming),
131132
]
132133
}
133134
}

Tests/AsyncHTTPClientTests/HTTPClientTests.swift

+44-16
Original file line numberDiff line numberDiff line change
@@ -2708,15 +2708,11 @@ class HTTPClientTests: XCTestCase {
27082708
let task = client.execute(request: request, delegate: TestHTTPDelegate())
27092709

27102710
XCTAssertThrowsError(try task.wait()) { error in
2711-
#if os(Linux)
2711+
if isTestingNIOTS() {
2712+
XCTAssertEqual(error as? ChannelError, .connectTimeout(.milliseconds(100)))
2713+
} else {
27122714
XCTAssertEqual(error as? NIOSSLError, NIOSSLError.uncleanShutdown)
2713-
#else
2714-
if isTestingNIOTS() {
2715-
XCTAssertEqual(error as? ChannelError, .connectTimeout(.milliseconds(100)))
2716-
} else {
2717-
XCTAssertEqual((error as? IOError).map { $0.errnoCode }, ECONNRESET)
2718-
}
2719-
#endif
2715+
}
27202716
}
27212717
}
27222718

@@ -2756,15 +2752,11 @@ class HTTPClientTests: XCTestCase {
27562752
let task = client.execute(request: request, delegate: TestHTTPDelegate())
27572753

27582754
XCTAssertThrowsError(try task.wait()) { error in
2759-
#if os(Linux)
2755+
if isTestingNIOTS() {
2756+
XCTAssertEqual(error as? ChannelError, .connectTimeout(.milliseconds(200)))
2757+
} else {
27602758
XCTAssertEqual(error as? NIOSSLError, NIOSSLError.uncleanShutdown)
2761-
#else
2762-
if isTestingNIOTS() {
2763-
XCTAssertEqual(error as? ChannelError, .connectTimeout(.milliseconds(200)))
2764-
} else {
2765-
XCTAssertEqual((error as? IOError).map { $0.errnoCode }, ECONNRESET)
2766-
}
2767-
#endif
2759+
}
27682760
}
27692761
}
27702762

@@ -2793,4 +2785,40 @@ class HTTPClientTests: XCTestCase {
27932785
let result = group.wait(timeout: DispatchTime.now() + DispatchTimeInterval.milliseconds(100))
27942786
XCTAssertEqual(result, .success, "we never closed the connection!")
27952787
}
2788+
2789+
func testBiDirectionalStreaming() throws {
2790+
let handler = HTTPEchoHandler()
2791+
2792+
let server = try ServerBootstrap(group: self.serverGroup)
2793+
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
2794+
.childChannelInitializer { channel in
2795+
channel.pipeline.configureHTTPServerPipeline().flatMap {
2796+
channel.pipeline.addHandler(handler)
2797+
}
2798+
}
2799+
.bind(host: "localhost", port: 0)
2800+
.wait()
2801+
2802+
defer {
2803+
server.close(promise: nil)
2804+
}
2805+
2806+
let body: HTTPClient.Body = .stream { writer in
2807+
let promise = self.clientGroup.next().makePromise(of: Void.self)
2808+
handler.promises.append(promise)
2809+
return writer.write(.byteBuffer(ByteBuffer(string: "hello"))).flatMap {
2810+
promise.futureResult
2811+
}.flatMap {
2812+
let promise = self.clientGroup.next().makePromise(of: Void.self)
2813+
handler.promises.append(promise)
2814+
return writer.write(.byteBuffer(ByteBuffer(string: "hello2"))).flatMap {
2815+
promise.futureResult
2816+
}
2817+
}
2818+
}
2819+
2820+
let future = self.defaultClient.execute(url: "http://localhost:\(server.localAddress!.port!)", body: body)
2821+
2822+
XCTAssertNoThrow(try future.wait())
2823+
}
27962824
}

0 commit comments

Comments
 (0)