Skip to content

[Backport] Fix exception propagation in Async API methods #1485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import com.mongodb.lang.Nullable;

import java.util.concurrent.atomic.AtomicBoolean;

/**
* See {@link AsyncRunnable}
* <p>
Expand All @@ -33,4 +35,28 @@ public interface AsyncFunction<T, R> {
* @param callback the callback
*/
void unsafeFinish(T value, SingleResultCallback<R> callback);

/**
* Must be invoked at end of async chain or when executing a callback handler supplied by the caller.
*
* @param callback the callback provided by the method the chain is used in.
*/
default void finish(final T value, final SingleResultCallback<R> callback) {
final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
try {
this.unsafeFinish(value, (v, e) -> {
if (!callbackInvoked.compareAndSet(false, true)) {
throw new AssertionError(String.format("Callback has been already completed. It could happen "
+ "if code throws an exception after invoking an async method. Value: %s", v), e);
}
callback.onResult(v, e);
});
} catch (Throwable t) {
if (!callbackInvoked.compareAndSet(false, true)) {
throw t;
} else {
callback.completeExceptionally(t);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) {
return (c) -> {
this.unsafeFinish((r, e) -> {
if (e == null) {
runnable.unsafeFinish(c);
/* If 'runnable' is executed on a different thread from the one that executed the initial 'finish()',
then invoking 'finish()' within 'runnable' will catch and propagate any exceptions to 'c' (the callback). */
runnable.finish(c);
} else {
c.completeExceptionally(e);
}
Expand Down Expand Up @@ -199,7 +201,7 @@ default AsyncRunnable thenRunIf(final Supplier<Boolean> condition, final AsyncRu
return;
}
if (matched) {
runnable.unsafeFinish(callback);
runnable.finish(callback);
} else {
callback.complete(callback);
}
Expand All @@ -216,7 +218,7 @@ default <R> AsyncSupplier<R> thenSupply(final AsyncSupplier<R> supplier) {
return (c) -> {
this.unsafeFinish((r, e) -> {
if (e == null) {
supplier.unsafeFinish(c);
supplier.finish(c);
} else {
c.completeExceptionally(e);
}
Expand Down
24 changes: 16 additions & 8 deletions driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.mongodb.lang.Nullable;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;


Expand Down Expand Up @@ -54,18 +55,25 @@ default void unsafeFinish(@Nullable final Void value, final SingleResultCallback
}

/**
* Must be invoked at end of async chain.
* Must be invoked at end of async chain or when executing a callback handler supplied by the caller.
*
* @see #thenApply(AsyncFunction)
* @see #thenConsume(AsyncConsumer)
* @see #onErrorIf(Predicate, AsyncFunction)
* @param callback the callback provided by the method the chain is used in
*/
default void finish(final SingleResultCallback<T> callback) {
final boolean[] callbackInvoked = {false};
final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
try {
this.unsafeFinish((v, e) -> {
callbackInvoked[0] = true;
if (!callbackInvoked.compareAndSet(false, true)) {
throw new AssertionError(String.format("Callback has been already completed. It could happen "
+ "if code throws an exception after invoking an async method. Value: %s", v), e);
}
callback.onResult(v, e);
});
} catch (Throwable t) {
if (callbackInvoked[0]) {
if (!callbackInvoked.compareAndSet(false, true)) {
throw t;
} else {
callback.completeExceptionally(t);
Expand All @@ -80,9 +88,9 @@ default void finish(final SingleResultCallback<T> callback) {
*/
default <R> AsyncSupplier<R> thenApply(final AsyncFunction<T, R> function) {
return (c) -> {
this.unsafeFinish((v, e) -> {
this.finish((v, e) -> {
if (e == null) {
function.unsafeFinish(v, c);
function.finish(v, c);
} else {
c.completeExceptionally(e);
}
Expand All @@ -99,7 +107,7 @@ default AsyncRunnable thenConsume(final AsyncConsumer<T> consumer) {
return (c) -> {
this.unsafeFinish((v, e) -> {
if (e == null) {
consumer.unsafeFinish(v, c);
consumer.finish(v, c);
} else {
c.completeExceptionally(e);
}
Expand Down Expand Up @@ -131,7 +139,7 @@ default AsyncSupplier<T> onErrorIf(
return;
}
if (errorMatched) {
errorFunction.unsafeFinish(e, callback);
errorFunction.finish(e, callback);
} else {
callback.completeExceptionally(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,7 @@ private <T> void sendCommandMessageAsync(final int messageId, final Decoder<T> d
return;
}
assertNotNull(responseBuffers);
T commandResult;
try {
updateSessionContext(sessionContext, responseBuffers);
boolean commandOk =
Expand All @@ -609,13 +610,14 @@ private <T> void sendCommandMessageAsync(final int messageId, final Decoder<T> d
}
commandEventSender.sendSucceededEvent(responseBuffers);

T result1 = getCommandResult(decoder, responseBuffers, messageId);
callback.onResult(result1, null);
commandResult = getCommandResult(decoder, responseBuffers, messageId);
} catch (Throwable localThrowable) {
callback.onResult(null, localThrowable);
return;
} finally {
responseBuffers.close();
}
callback.onResult(commandResult, null);
}));
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,14 @@ public void startHandshakeAsync(final InternalConnection internalConnection,
callback.onResult(null, t instanceof MongoException ? mapHelloException((MongoException) t) : t);
} else {
setSpeculativeAuthenticateResponse(helloResult);
callback.onResult(createInitializationDescription(helloResult, internalConnection, startTime), null);
InternalConnectionInitializationDescription initializationDescription;
try {
initializationDescription = createInitializationDescription(helloResult, internalConnection, startTime);
} catch (Throwable localThrowable) {
callback.onResult(null, localThrowable);
return;
}
callback.onResult(initializationDescription, null);
}
});
}
Expand Down
Loading