Skip to content

[Concurrency] Avoid inserting handler record in already cancelled task. #80456

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions stdlib/public/Concurrency/Task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1758,27 +1758,45 @@ swift_task_addCancellationHandlerImpl(
CancellationNotificationStatusRecord(unsigned_handler, context);

bool fireHandlerNow = false;

addStatusRecordToSelf(record, [&](ActiveTaskStatus oldStatus, ActiveTaskStatus& newStatus) {
if (oldStatus.isCancelled()) {
fireHandlerNow = true;
// We don't fire the cancellation handler here since this function needs
// to be idempotent
fireHandlerNow = true;

// don't add the record, because that would risk triggering it from
// task_cancel, concurrently with the record->run() we're about to do below.
return false;
}
return true;
return true; // add the record
});

if (fireHandlerNow) {
record->run();

// we have not added the record to the task because it has fired immediately,
// and therefore we can clean it up immediately rather than wait until removeCancellationHandler
// which would be triggered at the end of the withTaskCancellationHandler block.
swift_task_dealloc(record);
return nullptr; // indicate to the remove... method, that there was no task added
}
return record;
}

SWIFT_CC(swift)
static void swift_task_removeCancellationHandlerImpl(
CancellationNotificationStatusRecord *record) {
removeStatusRecordFromSelf(record);
swift_task_dealloc(record);
if (!record) {
// seems we never added the record but have run it immediately,
// so we make no attempts to remove it.
return;
}

if (auto poppedRecord =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way the remove and add are paired is such that the record should always be at the top here, so this removal should be exactly that record on the top.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to add an assert here that task->_private()._status().load(std::memory_order_relaxed).getInnermostRecord() == record to check that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ll add that

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding here #80522

popStatusRecordOfType<CancellationNotificationStatusRecord>(swift_task_getCurrent())) {
assert(record == poppedRecord && "The removed record did not match the expected record!");
swift_task_dealloc(record);
}
}

SWIFT_CC(swift)
Expand Down
10 changes: 10 additions & 0 deletions stdlib/public/Concurrency/TaskPrivate.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,16 @@ void removeStatusRecordWhere(
llvm::function_ref<bool(ActiveTaskStatus, TaskStatusRecord*)> condition,
llvm::function_ref<void(ActiveTaskStatus, ActiveTaskStatus&)>updateStatus = nullptr);

/// Remove and return a status record of the given type. This function removes a
/// singlw record, and leaves subsequent records as-is if there are any.
/// Returns `nullptr` if there are no matching records.
///
/// NOTE: When using this function with new record types, make sure to provide
/// an explicit instantiation in TaskStatus.cpp.
template <typename TaskStatusRecordT>
SWIFT_CC(swift)
TaskStatusRecordT* popStatusRecordOfType(AsyncTask *task);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to move the templated parts of the implementation into the header to avoid annoying link issues.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did try that but we don’t have the types fully declared there hmmm, I can try again. I’ll share the error

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fine to leave it if C++ objects too strenuously to my idea.

Copy link
Contributor Author

@ktoso ktoso Apr 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's a bit painful, we'd need complete type declarations for ActiveTaskStatus which is somewhat annoying to pull up completely.

/Users/ktoso/code/swift-project/swift/stdlib/public/Concurrency/TaskPrivate.h:255:27: warning: extra qualification on member 'popStatusRecordOfType' [-Wextra-qualification]
  255 | TaskStatusRecordT* swift::popStatusRecordOfType(AsyncTask *task) {
      |                           ^
/Users/ktoso/code/swift-project/swift/stdlib/public/Concurrency/TaskPrivate.h:255:27: error: out-of-line definition of 'popStatusRecordOfType' does not match any declaration in namespace 'swift'
  255 | TaskStatusRecordT* swift::popStatusRecordOfType(AsyncTask *task) {
      |                           ^~~~~~~~~~~~~~~~~~~~~
/Users/ktoso/code/swift-project/swift/stdlib/public/Concurrency/TaskPrivate.h:258:54: error: variable has incomplete type 'ActiveTaskStatus'
  258 |   removeStatusRecordWhere(task, [&](ActiveTaskStatus s, TaskStatusRecord *r) {
      |                                                      ^
/Users/ktoso/code/swift-project/swift/stdlib/public/Concurrency/TaskPrivate.h:58:7: note: forward declaration of 'swift::ActiveTaskStatus'
   58 | class ActiveTaskStatus;
      |       ^

I'll leave as is I think but thank you very much for review!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could have a small templatized implementation in the header that just retrieves the kind of the type being requested, then calls into a real non-templatized implementation which looks for that kind. But we can save that for when we actually need it.


/// Remove a status record from the current task. This must be called
/// synchronously with the task.
SWIFT_CC(swift)
Expand Down
17 changes: 12 additions & 5 deletions stdlib/public/Concurrency/TaskStatus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,15 +350,18 @@ void swift::removeStatusRecordWhere(
});
}

// Remove and return a status record of the given type. There must be at most
// one matching record. Returns nullptr if there are none.
template <typename TaskStatusRecordT>
static TaskStatusRecordT *popStatusRecordOfType(AsyncTask *task) {
SWIFT_CC(swift)
TaskStatusRecordT* swift::popStatusRecordOfType(AsyncTask *task) {
TaskStatusRecordT *record = nullptr;
bool alreadyRemovedRecord = false;
removeStatusRecordWhere(task, [&](ActiveTaskStatus s, TaskStatusRecord *r) {
if (alreadyRemovedRecord)
return false;

if (auto *match = dyn_cast<TaskStatusRecordT>(r)) {
assert(!record && "two matching records found");
record = match;
alreadyRemovedRecord = true;
return true; // Remove this record.
}

Expand Down Expand Up @@ -562,6 +565,10 @@ static void swift_task_popTaskExecutorPreferenceImpl(
swift_task_dealloc(record);
}

// Since the header would have incomplete declarations, we instead instantiate a concrete version of the function here
template SWIFT_CC(swift)
CancellationNotificationStatusRecord* swift::popStatusRecordOfType<CancellationNotificationStatusRecord>(AsyncTask *);

void AsyncTask::pushInitialTaskExecutorPreference(
TaskExecutorRef preferredExecutor, bool owned) {
void *allocation = _swift_task_alloc_specific(
Expand Down Expand Up @@ -879,7 +886,7 @@ static void swift_task_cancelImpl(AsyncTask *task) {
}

newStatus.traceStatusChanged(task, false);
if (newStatus.getInnermostRecord() == NULL) {
if (newStatus.getInnermostRecord() == nullptr) {
// No records, nothing to propagate
return;
}
Expand Down
62 changes: 62 additions & 0 deletions test/Concurrency/Runtime/cancellation_handler_only_once.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// RUN: %target-run-simple-swift( -Xfrontend -disable-availability-checking -target %target-swift-5.1-abi-triple %import-libdispatch) | %FileCheck %s
// REQUIRES: concurrency
// REQUIRES: executable_test

// rdar://76038845
// REQUIRES: concurrency_runtime
// UNSUPPORTED: back_deployment_runtime
// UNSUPPORTED: freestanding

import Synchronization

struct State {
var cancelled = 0
var continuation: CheckedContinuation<Void, Never>?
}

func testFunc(_ iteration: Int) async -> Task<Void, Never> {
let state = Mutex(State())

let task = Task {
await withTaskCancellationHandler {
await withCheckedContinuation { continuation in
let cancelled = state.withLock {
if $0.cancelled > 0 {
return true
} else {
$0.continuation = continuation
return false
}
}
if cancelled {
continuation.resume()
}
}
} onCancel: {
let continuation = state.withLock {
$0.cancelled += 1
return $0.continuation.take()
}
continuation?.resume()
}
}

// This task cancel is racing with installing the cancellation handler,
// and we may either hit the cancellation handler:
// - after this cancel was issued, and therefore the handler runs immediately
task.cancel()
_ = await task.value

let cancelled = state.withLock { $0.cancelled }
precondition(cancelled == 1, "cancelled more than once, iteration: \(iteration)")

return task
}

var ts: [Task<Void, Never>] = []
for iteration in 0..<1_000 {
let t = await testFunc(iteration)
ts.append(t)
}

print("done") // CHECK: done