Skip to content

Commit ee160ee

Browse files
committed
protected iterator with actor
1 parent 8abc335 commit ee160ee

File tree

1 file changed

+27
-36
lines changed

1 file changed

+27
-36
lines changed

Sources/AsyncAlgorithms/AsyncSharedSequence.swift

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ extension AsyncSharedSequence: AsyncSequence, Sendable {
129129
guard let state else { return nil }
130130
let command = state.run(id)
131131
switch command {
132-
case .fetch(var iterator):
133-
let upstreamResult = await Result.async { try await iterator.next() }
132+
case .fetch(let iterator):
133+
let upstreamResult = await iterator.next()
134134
let output = state.fetch(id, resumedWithResult: upstreamResult, iterator: iterator)
135135
return try processOutput(output)
136136
case .wait:
@@ -194,9 +194,9 @@ fileprivate extension AsyncSharedSequence {
194194
deinit { action() }
195195
}
196196

197-
struct SharedUpstreamIterator: Sendable {
197+
actor SharedUpstreamIterator {
198198

199-
private enum State: @unchecked Sendable {
199+
private enum State {
200200
case pending
201201
case active(Base.AsyncIterator)
202202
case terminal
@@ -207,36 +207,37 @@ fileprivate extension AsyncSharedSequence {
207207
return true
208208
}
209209

210-
private let createIterator: @Sendable () -> Base.AsyncIterator
210+
private let base: Base
211211
private var state = State.pending
212212

213-
init(_ createIterator: @escaping @Sendable () -> Base.AsyncIterator) {
214-
self.createIterator = createIterator
213+
init(_ base: Base) {
214+
self.base = base
215215
}
216216

217-
mutating func next() async rethrows -> Element? {
217+
func next() async -> Result<Element?, Error> {
218218
switch state {
219219
case .pending:
220-
self.state = .active(createIterator())
221-
return try await next()
220+
self.state = .active(base.makeAsyncIterator())
221+
return await next()
222222
case .active(var iterator):
223-
let result = await Result.async { try await iterator.next() }
224-
switch result {
225-
case .success(_?):
226-
self.state = .active(iterator)
227-
default:
223+
do {
224+
if let element = try await iterator.next() {
225+
self.state = .active(iterator)
226+
return .success(element)
227+
}
228+
else {
229+
self.state = .terminal
230+
return .success(nil)
231+
}
232+
}
233+
catch {
228234
self.state = .terminal
235+
return .failure(error)
229236
}
230-
return try result._rethrowGet()
231237
case .terminal:
232-
return nil
238+
return .success(nil)
233239
}
234240
}
235-
236-
mutating func reset() {
237-
guard case .active(_) = state else { return }
238-
self.state = .pending
239-
}
240241
}
241242

242243
struct Runner {
@@ -294,6 +295,7 @@ fileprivate extension AsyncSharedSequence {
294295

295296
private struct Storage: Sendable {
296297

298+
let base: Base
297299
let replayCount: Int
298300
let iteratorDiscardPolicy: IteratorDisposalPolicy
299301
var iterator: SharedUpstreamIterator?
@@ -311,9 +313,10 @@ fileprivate extension AsyncSharedSequence {
311313

312314
init(_ base: Base, replayCount: Int, discardsIterator: IteratorDisposalPolicy) {
313315
precondition(replayCount >= 0, "history must be greater than or equal to zero")
316+
self.base = base
314317
self.replayCount = replayCount
315318
self.iteratorDiscardPolicy = discardsIterator
316-
self.iterator = .init { base.makeAsyncIterator() }
319+
self.iterator = .init(base)
317320
}
318321

319322
mutating func establish() -> Connection {
@@ -483,7 +486,7 @@ fileprivate extension AsyncSharedSequence {
483486
self.currentGroup.flip()
484487
self.phase = .pending
485488
if runners.isEmpty && iteratorDiscardPolicy == .whenTerminatedOrVacant {
486-
self.iterator?.reset()
489+
self.iterator = SharedUpstreamIterator(base)
487490
self.history.removeAll()
488491
}
489492
}
@@ -564,15 +567,3 @@ fileprivate extension AsyncSharedSequence {
564567
}
565568
}
566569
}
567-
568-
fileprivate extension Result where Failure == Error {
569-
570-
static func async(_ operation: @escaping () async throws -> Success) async -> Self {
571-
do {
572-
return .success(try await operation())
573-
}
574-
catch let error {
575-
return .failure(error)
576-
}
577-
}
578-
}

0 commit comments

Comments
 (0)