Skip to content

Commit d604dd0

Browse files
paulhdk and czechboy0 authored
EventStreams: Customisable Terminating Byte Sequence (#115)
### Motivation

As discussed in apple/swift-openapi-generator#622, some APIs, e.g., ChatGPT or Claude, may return a non-JSON byte sequence to terminate a stream of events. If not handled with a workaround (see below), such non-JSON terminating byte sequences cause a decoding error.

### Modifications

This PR adds the ability to customise the terminating byte sequence by providing a closure to `asDecodedServerSentEvents()` as well as `asDecodedServerSentEventsWithJSONData()` that can match incoming data for the terminating byte sequence before it is decoded into JSON, for instance.

### Result

Instead of having to decode and re-encode incoming events to filter out the terminating byte sequence - as seen in apple/swift-openapi-generator#622 (comment) - terminating byte sequences can now be cleanly caught by either providing a closure or providing the terminating byte sequence directly when calling `asDecodedServerSentEvents()` and `asDecodedServerSentEventsWithJSONData()`.

### Test Plan

This PR includes unit tests that test the new function parameters as part of the existing tests for `asDecodedServerSentEvents()` as well as `asDecodedServerSentEventsWithJSONData()`.

---------

Co-authored-by: Honza Dvorsky <[email protected]>
1 parent da2e5b8 commit d604dd0

File tree

3 files changed

+140
-27
lines changed

3 files changed

+140
-27
lines changed

Sources/OpenAPIRuntime/Deprecated/Deprecated.swift

+34
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,37 @@ extension Configuration {
5959
)
6060
}
6161
}
62+
63+
extension AsyncSequence where Element == ArraySlice<UInt8>, Self: Sendable {
64+
/// Returns another sequence that decodes each event's data as the provided type using the provided decoder.
65+
///
66+
/// Use this method if the event's `data` field is not JSON, or if you don't want to parse it using `asDecodedServerSentEventsWithJSONData`.
67+
/// - Returns: A sequence that provides the events.
68+
@available(*, deprecated, renamed: "asDecodedServerSentEvents(while:)") @_disfavoredOverload
69+
public func asDecodedServerSentEvents() -> ServerSentEventsDeserializationSequence<
70+
ServerSentEventsLineDeserializationSequence<Self>
71+
> { asDecodedServerSentEvents(while: { _ in true }) }
72+
/// Returns another sequence that decodes each event's data as the provided type using the provided decoder.
73+
///
74+
/// Use this method if the event's `data` field is JSON.
75+
/// - Parameters:
76+
/// - dataType: The type to decode the JSON data into.
77+
/// - decoder: The JSON decoder to use.
78+
/// - Returns: A sequence that provides the events with the decoded JSON data.
79+
@available(*, deprecated, renamed: "asDecodedServerSentEventsWithJSONData(of:decoder:while:)") @_disfavoredOverload
80+
public func asDecodedServerSentEventsWithJSONData<JSONDataType: Decodable>(
81+
of dataType: JSONDataType.Type = JSONDataType.self,
82+
decoder: JSONDecoder = .init()
83+
) -> AsyncThrowingMapSequence<
84+
ServerSentEventsDeserializationSequence<ServerSentEventsLineDeserializationSequence<Self>>,
85+
ServerSentEventWithJSONData<JSONDataType>
86+
> { asDecodedServerSentEventsWithJSONData(of: dataType, decoder: decoder, while: { _ in true }) }
87+
}
88+
89+
extension ServerSentEventsDeserializationSequence {
90+
/// Creates a new sequence.
91+
/// - Parameter upstream: The upstream sequence of arbitrary byte chunks.
92+
@available(*, deprecated, renamed: "init(upstream:while:)") @_disfavoredOverload public init(upstream: Upstream) {
93+
self.init(upstream: upstream, while: { _ in true })
94+
}
95+
}

Sources/OpenAPIRuntime/EventStreams/ServerSentEventsDecoding.swift

+50-21
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,19 @@ where Upstream.Element == ArraySlice<UInt8> {
2828
/// The upstream sequence.
2929
private let upstream: Upstream
3030

31+
/// A closure that determines whether the given byte chunk should be forwarded to the consumer.
32+
/// - Parameter: A byte chunk.
33+
/// - Returns: `true` if the byte chunk should be forwarded, `false` if this byte chunk is the terminating sequence.
34+
private let predicate: @Sendable (ArraySlice<UInt8>) -> Bool
35+
3136
/// Creates a new sequence.
32-
/// - Parameter upstream: The upstream sequence of arbitrary byte chunks.
33-
public init(upstream: Upstream) { self.upstream = upstream }
37+
/// - Parameters:
38+
/// - upstream: The upstream sequence of arbitrary byte chunks.
39+
/// - predicate: A closure that determines whether the given byte chunk should be forwarded to the consumer.
40+
public init(upstream: Upstream, while predicate: @escaping @Sendable (ArraySlice<UInt8>) -> Bool) {
41+
self.upstream = upstream
42+
self.predicate = predicate
43+
}
3444
}
3545

3646
extension ServerSentEventsDeserializationSequence: AsyncSequence {
@@ -46,7 +56,16 @@ extension ServerSentEventsDeserializationSequence: AsyncSequence {
4656
var upstream: UpstreamIterator
4757

4858
/// The state machine of the iterator.
49-
var stateMachine: StateMachine = .init()
59+
var stateMachine: StateMachine
60+
61+
/// Creates a new iterator.
62+
/// - Parameters:
63+
/// - upstream: The iterator of the upstream sequence of arbitrary byte chunks.
64+
/// - predicate: A closure that determines whether the given byte chunk should be forwarded to the consumer.
65+
init(upstream: UpstreamIterator, while predicate: @escaping ((ArraySlice<UInt8>) -> Bool)) {
66+
self.upstream = upstream
67+
self.stateMachine = .init(while: predicate)
68+
}
5069

5170
/// Asynchronously advances to the next element and returns it, or ends the
5271
/// sequence if there is no next element.
@@ -70,7 +89,7 @@ extension ServerSentEventsDeserializationSequence: AsyncSequence {
7089
/// Creates the asynchronous iterator that produces elements of this
7190
/// asynchronous sequence.
7291
public func makeAsyncIterator() -> Iterator<Upstream.AsyncIterator> {
73-
Iterator(upstream: upstream.makeAsyncIterator())
92+
Iterator(upstream: upstream.makeAsyncIterator(), while: predicate)
7493
}
7594
}
7695

@@ -79,26 +98,30 @@ extension AsyncSequence where Element == ArraySlice<UInt8>, Self: Sendable {
7998
/// Returns another sequence that decodes each event's data as the provided type using the provided decoder.
8099
///
81100
/// Use this method if the event's `data` field is not JSON, or if you don't want to parse it using `asDecodedServerSentEventsWithJSONData`.
101+
/// - Parameter predicate: A closure that determines whether the given byte chunk should be forwarded to the consumer.
82102
/// - Returns: A sequence that provides the events.
83-
public func asDecodedServerSentEvents() -> ServerSentEventsDeserializationSequence<
84-
ServerSentEventsLineDeserializationSequence<Self>
85-
> { .init(upstream: ServerSentEventsLineDeserializationSequence(upstream: self)) }
86-
103+
public func asDecodedServerSentEvents(
104+
while predicate: @escaping @Sendable (ArraySlice<UInt8>) -> Bool = { _ in true }
105+
) -> ServerSentEventsDeserializationSequence<ServerSentEventsLineDeserializationSequence<Self>> {
106+
.init(upstream: ServerSentEventsLineDeserializationSequence(upstream: self), while: predicate)
107+
}
87108
/// Returns another sequence that decodes each event's data as the provided type using the provided decoder.
88109
///
89110
/// Use this method if the event's `data` field is JSON.
90111
/// - Parameters:
91112
/// - dataType: The type to decode the JSON data into.
92113
/// - decoder: The JSON decoder to use.
114+
/// - predicate: A closure that determines whether the given byte sequence should be forwarded to the consumer; return `false` for the terminating byte sequence defined by the API.
93115
/// - Returns: A sequence that provides the events with the decoded JSON data.
94116
public func asDecodedServerSentEventsWithJSONData<JSONDataType: Decodable>(
95117
of dataType: JSONDataType.Type = JSONDataType.self,
96-
decoder: JSONDecoder = .init()
118+
decoder: JSONDecoder = .init(),
119+
while predicate: @escaping @Sendable (ArraySlice<UInt8>) -> Bool = { _ in true }
97120
) -> AsyncThrowingMapSequence<
98121
ServerSentEventsDeserializationSequence<ServerSentEventsLineDeserializationSequence<Self>>,
99122
ServerSentEventWithJSONData<JSONDataType>
100123
> {
101-
asDecodedServerSentEvents()
124+
asDecodedServerSentEvents(while: predicate)
102125
.map { event in
103126
ServerSentEventWithJSONData(
104127
event: event.event,
@@ -118,10 +141,10 @@ extension ServerSentEventsDeserializationSequence.Iterator {
118141
struct StateMachine {
119142

120143
/// The possible states of the state machine.
121-
enum State: Hashable {
144+
enum State {
122145

123146
/// Accumulating an event, which hasn't been emitted yet.
124-
case accumulatingEvent(ServerSentEvent, buffer: [ArraySlice<UInt8>])
147+
case accumulatingEvent(ServerSentEvent, buffer: [ArraySlice<UInt8>], predicate: (ArraySlice<UInt8>) -> Bool)
125148

126149
/// Finished, the terminal state.
127150
case finished
@@ -134,7 +157,9 @@ extension ServerSentEventsDeserializationSequence.Iterator {
134157
private(set) var state: State
135158

136159
/// Creates a new state machine.
137-
init() { self.state = .accumulatingEvent(.init(), buffer: []) }
160+
init(while predicate: @escaping (ArraySlice<UInt8>) -> Bool) {
161+
self.state = .accumulatingEvent(.init(), buffer: [], predicate: predicate)
162+
}
138163

139164
/// An action returned by the `next` method.
140165
enum NextAction {
@@ -156,20 +181,24 @@ extension ServerSentEventsDeserializationSequence.Iterator {
156181
/// - Returns: An action to perform.
157182
mutating func next() -> NextAction {
158183
switch state {
159-
case .accumulatingEvent(var event, var buffer):
184+
case .accumulatingEvent(var event, var buffer, let predicate):
160185
guard let line = buffer.first else { return .needsMore }
161186
state = .mutating
162187
buffer.removeFirst()
163188
if line.isEmpty {
164189
// Dispatch the accumulated event.
165-
state = .accumulatingEvent(.init(), buffer: buffer)
166190
// If the last character of data is a newline, strip it.
167191
if event.data?.hasSuffix("\n") ?? false { event.data?.removeLast() }
192+
if let data = event.data, !predicate(ArraySlice(data.utf8)) {
193+
state = .finished
194+
return .returnNil
195+
}
196+
state = .accumulatingEvent(.init(), buffer: buffer, predicate: predicate)
168197
return .emitEvent(event)
169198
}
170199
if line.first! == ASCII.colon {
171200
// A comment, skip this line.
172-
state = .accumulatingEvent(event, buffer: buffer)
201+
state = .accumulatingEvent(event, buffer: buffer, predicate: predicate)
173202
return .noop
174203
}
175204
// Parse the field name and value.
@@ -193,7 +222,7 @@ extension ServerSentEventsDeserializationSequence.Iterator {
193222
}
194223
guard let value else {
195224
// An unknown type of event, skip.
196-
state = .accumulatingEvent(event, buffer: buffer)
225+
state = .accumulatingEvent(event, buffer: buffer, predicate: predicate)
197226
return .noop
198227
}
199228
// Process the field.
@@ -214,11 +243,11 @@ extension ServerSentEventsDeserializationSequence.Iterator {
214243
}
215244
default:
216245
// An unknown or invalid field, skip.
217-
state = .accumulatingEvent(event, buffer: buffer)
246+
state = .accumulatingEvent(event, buffer: buffer, predicate: predicate)
218247
return .noop
219248
}
220249
// Processed the field, continue.
221-
state = .accumulatingEvent(event, buffer: buffer)
250+
state = .accumulatingEvent(event, buffer: buffer, predicate: predicate)
222251
return .noop
223252
case .finished: return .returnNil
224253
case .mutating: preconditionFailure("Invalid state")
@@ -240,11 +269,11 @@ extension ServerSentEventsDeserializationSequence.Iterator {
240269
/// - Returns: An action to perform.
241270
mutating func receivedValue(_ value: ArraySlice<UInt8>?) -> ReceivedValueAction {
242271
switch state {
243-
case .accumulatingEvent(let event, var buffer):
272+
case .accumulatingEvent(let event, var buffer, let predicate):
244273
if let value {
245274
state = .mutating
246275
buffer.append(value)
247-
state = .accumulatingEvent(event, buffer: buffer)
276+
state = .accumulatingEvent(event, buffer: buffer, predicate: predicate)
248277
return .noop
249278
} else {
250279
// If no value is received, drop the existing event on the floor.

Tests/OpenAPIRuntimeTests/EventStreams/Test_ServerSentEventsDecoding.swift

+56-6
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,22 @@ import XCTest
1616
import Foundation
1717

1818
final class Test_ServerSentEventsDecoding: Test_Runtime {
19-
func _test(input: String, output: [ServerSentEvent], file: StaticString = #filePath, line: UInt = #line)
20-
async throws
21-
{
22-
let sequence = asOneBytePerElementSequence(ArraySlice(input.utf8)).asDecodedServerSentEvents()
19+
func _test(
20+
input: String,
21+
output: [ServerSentEvent],
22+
file: StaticString = #filePath,
23+
line: UInt = #line,
24+
while predicate: @escaping @Sendable (ArraySlice<UInt8>) -> Bool = { _ in true }
25+
) async throws {
26+
let sequence = asOneBytePerElementSequence(ArraySlice(input.utf8)).asDecodedServerSentEvents(while: predicate)
2327
let events = try await [ServerSentEvent](collecting: sequence)
2428
XCTAssertEqual(events.count, output.count, file: file, line: line)
2529
for (index, linePair) in zip(events, output).enumerated() {
2630
let (actualEvent, expectedEvent) = linePair
2731
XCTAssertEqual(actualEvent, expectedEvent, "Event: \(index)", file: file, line: line)
2832
}
2933
}
34+
3035
func test() async throws {
3136
// Simple event.
3237
try await _test(
@@ -83,22 +88,40 @@ final class Test_ServerSentEventsDecoding: Test_Runtime {
8388
.init(id: "123", data: "This is a message with an ID."),
8489
]
8590
)
91+
92+
try await _test(
93+
input: #"""
94+
data: hello
95+
data: world
96+
97+
data: [DONE]
98+
99+
data: hello2
100+
data: world2
101+
102+
103+
"""#,
104+
output: [.init(data: "hello\nworld")],
105+
while: { incomingData in incomingData != ArraySlice<UInt8>(Data("[DONE]".utf8)) }
106+
)
86107
}
87108
func _testJSONData<JSONType: Decodable & Hashable & Sendable>(
88109
input: String,
89110
output: [ServerSentEventWithJSONData<JSONType>],
90111
file: StaticString = #filePath,
91-
line: UInt = #line
112+
line: UInt = #line,
113+
while predicate: @escaping @Sendable (ArraySlice<UInt8>) -> Bool = { _ in true }
92114
) async throws {
93115
let sequence = asOneBytePerElementSequence(ArraySlice(input.utf8))
94-
.asDecodedServerSentEventsWithJSONData(of: JSONType.self)
116+
.asDecodedServerSentEventsWithJSONData(of: JSONType.self, while: predicate)
95117
let events = try await [ServerSentEventWithJSONData<JSONType>](collecting: sequence)
96118
XCTAssertEqual(events.count, output.count, file: file, line: line)
97119
for (index, linePair) in zip(events, output).enumerated() {
98120
let (actualEvent, expectedEvent) = linePair
99121
XCTAssertEqual(actualEvent, expectedEvent, "Event: \(index)", file: file, line: line)
100122
}
101123
}
124+
102125
struct TestEvent: Decodable, Hashable, Sendable { var index: Int }
103126
func testJSONData() async throws {
104127
// Simple event.
@@ -121,6 +144,33 @@ final class Test_ServerSentEventsDecoding: Test_Runtime {
121144
.init(event: "event2", data: TestEvent(index: 2), id: "2"),
122145
]
123146
)
147+
148+
try await _testJSONData(
149+
input: #"""
150+
event: event1
151+
id: 1
152+
data: {"index":1}
153+
154+
event: event2
155+
id: 2
156+
data: {
157+
data: "index": 2
158+
data: }
159+
160+
data: [DONE]
161+
162+
event: event3
163+
id: 1
164+
data: {"index":3}
165+
166+
167+
"""#,
168+
output: [
169+
.init(event: "event1", data: TestEvent(index: 1), id: "1"),
170+
.init(event: "event2", data: TestEvent(index: 2), id: "2"),
171+
],
172+
while: { incomingData in incomingData != ArraySlice<UInt8>(Data("[DONE]".utf8)) }
173+
)
124174
}
125175
}
126176

0 commit comments

Comments
 (0)