Ensure Sink.contextView is propagated #1450

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 10 commits · Aug 1, 2024
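For orientation, the change pattern is the same across all four files below: Reactor's `MonoSink`/`FluxSink` expose the downstream subscriber's context via `contextView()`, and any publisher the driver subscribes to internally must have that view written back with `contextWrite`, otherwise the inner subscription starts from an empty `Context`. A minimal, self-contained sketch of the pattern (the helper `innerOperation` is hypothetical; the Reactor calls are real):

```java
import reactor.core.publisher.Mono;
import reactor.util.context.Context;

public final class ContextPropagationSketch {

    // Hypothetical stand-in for a driver-internal publisher; it reads a value
    // from whatever Reactor Context it is subscribed with.
    static Mono<String> innerOperation() {
        return Mono.deferContextual(ctx ->
                Mono.just("traceId=" + ctx.getOrDefault("traceId", "<missing>")));
    }

    public static void main(String[] args) {
        Mono<String> wrapped = Mono.create(sink ->
                innerOperation()
                        // Without this line, the inner subscribe() starts from an
                        // empty Context and the caller's entries are invisible here.
                        .contextWrite(sink.contextView())
                        .subscribe(sink::success, sink::error));

        wrapped.contextWrite(Context.of("traceId", "abc-123"))
                .subscribe(System.out::println); // prints traceId=abc-123
    }
}
```

Deleting the `contextWrite(sink.contextView())` line makes this print `traceId=<missing>` — the broken chain this PR repairs at each internal `subscribe()` call site.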
```diff
@@ -18,11 +18,9 @@

 import org.reactivestreams.Publisher;
 import org.reactivestreams.Subscriber;
-import reactor.core.CoreSubscriber;
 import reactor.core.publisher.Flux;
 import reactor.core.publisher.FluxSink;
 import reactor.core.publisher.Mono;
-import reactor.util.context.Context;

 import java.util.Objects;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -48,9 +46,9 @@ public void subscribe(final Subscriber<? super T> subscriber) {
         if (calculateDemand(demand) > 0 && inProgress.compareAndSet(false, true)) {
             if (batchCursor == null) {
                 int batchSize = calculateBatchSize(sink.requestedFromDownstream());
-                Context initialContext = subscriber instanceof CoreSubscriber<?>
-                        ? ((CoreSubscriber<?>) subscriber).currentContext() : null;
-                batchCursorPublisher.batchCursor(batchSize).subscribe(bc -> {
+                batchCursorPublisher.batchCursor(batchSize)
+                        .contextWrite(sink.contextView())
+                        .subscribe(bc -> {
                     batchCursor = bc;
                     inProgress.set(false);

@@ -60,7 +58,7 @@ public void subscribe(final Subscriber<? super T> subscriber) {
                 } else {
                     recurseCursor();
                 }
-                }, sink::error, null, initialContext);
+                }, sink::error);
             } else {
                 inProgress.set(false);
                 recurseCursor();
@@ -86,6 +84,7 @@ private void recurseCursor(){
         } else {
             batchCursor.setBatchSize(calculateBatchSize(sink.requestedFromDownstream()));
             Mono.from(batchCursor.next(() -> sink.isCancelled()))
+                    .contextWrite(sink.contextView())
                     .doOnCancel(this::closeCursor)
                     .subscribe(results -> {
                         if (!results.isEmpty()) {
```
```diff
@@ -123,21 +123,17 @@ public TimeoutMode getTimeoutMode() {

     public Publisher<T> first() {
         return batchCursor(this::asAsyncFirstReadOperation)
-                .flatMap(batchCursor -> Mono.create(sink -> {
+                .flatMap(batchCursor -> {
                     batchCursor.setBatchSize(1);
-                    Mono.from(batchCursor.next())
+                    return Mono.from(batchCursor.next())
                             .doOnTerminate(batchCursor::close)
-                            .doOnError(sink::error)
-                            .doOnSuccess(results -> {
+                            .flatMap(results -> {
                                 if (results == null || results.isEmpty()) {
-                                    sink.success();
-                                } else {
-                                    sink.success(results.get(0));
+                                    return Mono.empty();
                                 }
-                            })
-                            .contextWrite(sink.contextView())
-                            .subscribe();
-                }));
+                                return Mono.fromCallable(() -> results.get(0));
+                            });
+                });
     }

     @Override
```
```diff
@@ -306,6 +306,7 @@ private void collInfo(final MongoCryptContext cryptContext,
             sink.error(new IllegalStateException("Missing database name"));
         } else {
             collectionInfoRetriever.filter(databaseName, cryptContext.getMongoOperation(), operationTimeout)
+                    .contextWrite(sink.contextView())
                     .doOnSuccess(result -> {
                         if (result != null) {
                             cryptContext.addMongoOperationResult(result);
@@ -328,6 +329,7 @@ private void mark(final MongoCryptContext cryptContext,
             sink.error(wrapInClientException(new IllegalStateException("Missing database name")));
         } else {
             commandMarker.mark(databaseName, cryptContext.getMongoOperation(), operationTimeout)
+                    .contextWrite(sink.contextView())
                     .doOnSuccess(result -> {
                         cryptContext.addMongoOperationResult(result);
                         cryptContext.completeMongoOperation();
@@ -343,6 +345,7 @@ private void fetchKeys(final MongoCryptContext cryptContext,
                            final MonoSink<RawBsonDocument> sink,
                            @Nullable final Timeout operationTimeout) {
         keyRetriever.find(cryptContext.getMongoOperation(), operationTimeout)
+                .contextWrite(sink.contextView())
                 .doOnSuccess(results -> {
                     for (BsonDocument result : results) {
                         cryptContext.addMongoOperationResult(result);
@@ -361,11 +364,13 @@ private void decryptKeys(final MongoCryptContext cryptContext,
         MongoKeyDecryptor keyDecryptor = cryptContext.nextKeyDecryptor();
         if (keyDecryptor != null) {
             keyManagementService.decryptKey(keyDecryptor, operationTimeout)
+                    .contextWrite(sink.contextView())
                     .doOnSuccess(r -> decryptKeys(cryptContext, databaseName, sink, operationTimeout))
                     .doOnError(e -> sink.error(wrapInClientException(e)))
                     .subscribe();
         } else {
             Mono.fromRunnable(cryptContext::completeKeyDecryptors)
+                    .contextWrite(sink.contextView())
                     .doOnSuccess(r -> executeStateMachineWithSink(cryptContext, databaseName, sink, operationTimeout))
                     .doOnError(e -> sink.error(wrapInClientException(e)))
                     .subscribe();
```
```diff
@@ -40,9 +40,6 @@
 import java.util.Date;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicLong;
-import java.util.function.Function;

 import static com.mongodb.ReadPreference.primary;
 import static com.mongodb.assertions.Assertions.notNull;
@@ -106,7 +103,7 @@ public BsonValue getId() {

     @Override
     public void subscribe(final Subscriber<? super Void> s) {
-        Mono.defer(() -> {
+        Mono.deferContextual(ctx -> {
             AtomicBoolean terminated = new AtomicBoolean(false);
             Timeout timeout = TimeoutContext.startTimeout(timeoutMs);
             return createCheckAndCreateIndexesMono(timeout)
@@ -120,7 +117,7 @@ public void subscribe(final Subscriber<? super Void> s) {
                             return originalError;
                         })
                         .then(Mono.error(originalError)))
-                .doOnCancel(() -> createCancellationMono(terminated, timeout).subscribe())
+                .doOnCancel(() -> createCancellationMono(terminated, timeout).contextWrite(ctx).subscribe())
```
**Member (author):** Cancellation has a side effect, so we need to propagate context there.

**Member:** Isn't the only thing that requires us to propagate the context from the subscriber `s` the fact that we call subscribe() here with a subscriber other than the one that was given to us (the `s` argument)? I suspect that the fact that `createCancellationMono` has a side effect (it mutates `terminated`) is irrelevant. If the side effect is relevant to context propagation, could you please explain why?

**Member (author):** On cancellation of the main Publisher we call subscribe() on an async clean-up Publisher, and we ignore any errors, results, and completion.

> If the side effect is relevant to context propagation, could you please explain why?

The user may have tracing enabled and want to trace from the web app into the driver. Without adding the context here we break that chain, and they cannot tie the clean-up operations to a web request.
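To make the clean-up scenario concrete, a hedged, self-contained sketch (all names hypothetical, not driver code) of why a publisher subscribed from `doOnCancel` needs the captured context written back:

```java
import java.time.Duration;
import reactor.core.Disposable;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;

public final class CancellationContextSketch {

    // Hypothetical clean-up publisher: a brand-new subscription sees an empty
    // Context unless the caller's context is written back in.
    static Mono<Void> cleanup() {
        return Mono.deferContextual(ctx -> {
            System.out.println("cleanup traceId=" + ctx.getOrDefault("traceId", "<missing>"));
            return Mono.empty();
        });
    }

    public static void main(String[] args) {
        Mono<Void> mainWork = Mono.deferContextual(ctx ->
                Mono.delay(Duration.ofSeconds(10))   // stands in for the main upload work
                        .doOnCancel(() -> cleanup()
                                .contextWrite(ctx)   // ties the clean-up to the request
                                .subscribe())
                        .then());

        Disposable subscription = mainWork
                .contextWrite(Context.of("traceId", "abc-123"))
                .subscribe();
        subscription.dispose(); // cancel the main work: prints "cleanup traceId=abc-123"
    }
}
```

With the `contextWrite(ctx)` line removed, the clean-up prints `traceId=<missing>`, and a tracing system cannot attach the clean-up operation to the originating web request.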

```diff
                 .then();
         }).subscribe(s);
     }
@@ -149,38 +146,15 @@ public void subscribe(final Subscriber<? super ObjectId> subscriber) {
     }

     private Mono<Void> createCheckAndCreateIndexesMono(@Nullable final Timeout timeout) {
-        AtomicBoolean collectionExists = new AtomicBoolean(false);
-        return Mono.create(sink -> findAllInCollection(filesCollection, timeout).subscribe(
-                d -> collectionExists.set(true),
-                sink::error,
-                () -> {
-                    if (collectionExists.get()) {
-                        sink.success();
-                    } else {
-                        checkAndCreateIndex(filesCollection.withReadPreference(primary()), FILES_INDEX, timeout)
-                                .doOnSuccess(i -> checkAndCreateIndex(chunksCollection.withReadPreference(primary()), CHUNKS_INDEX, timeout)
-                                        .subscribe(unused -> {}, sink::error, sink::success))
-                                .subscribe(unused -> {}, sink::error);
-                    }
-                })
-        );
-    }
-
-    private Mono<Document> findAllInCollection(final MongoCollection<GridFSFile> collection, @Nullable final Timeout timeout) {
-        return collectionWithTimeoutDeferred(collection
-                .withDocumentClass(Document.class)
-                .withReadPreference(primary()), timeout)
-                .flatMap(wrappedCollection -> {
-                    if (clientSession != null) {
-                        return Mono.from(wrappedCollection.find(clientSession)
-                                .projection(PROJECTION)
-                                .first());
-                    } else {
-                        return Mono.from(wrappedCollection.find()
-                                .projection(PROJECTION)
-                                .first());
-                    }
-                });
+        return collectionWithTimeoutDeferred(filesCollection.withDocumentClass(Document.class).withReadPreference(primary()), timeout)
+                .map(collection -> clientSession != null ? collection.find(clientSession) : collection.find())
+                .flatMap(findPublisher -> Mono.from(findPublisher.projection(PROJECTION).first()))
+                .switchIfEmpty(Mono.defer(() ->
+                        checkAndCreateIndex(filesCollection.withReadPreference(primary()), FILES_INDEX, timeout)
+                                .then(checkAndCreateIndex(chunksCollection.withReadPreference(primary()), CHUNKS_INDEX, timeout))
+                                .then(Mono.empty())
+                ))
+                .then();
     }

     private <T> Mono<Boolean> hasIndex(final MongoCollection<T> collection, final Document index, @Nullable final Timeout timeout) {
@@ -228,40 +202,37 @@ private <T> Mono<String> createIndexMono(final MongoCollection<T> collection, fi
     }

     private Mono<Long> createSaveChunksMono(final AtomicBoolean terminated, @Nullable final Timeout timeout) {
-        return Mono.create(sink -> {
-            AtomicLong lengthInBytes = new AtomicLong(0);
-            AtomicInteger chunkIndex = new AtomicInteger(0);
-            new ResizingByteBufferFlux(source, chunkSizeBytes)
-                    .takeUntilOther(createMonoTimer(timeout))
-                    .flatMap((Function<ByteBuffer, Publisher<InsertOneResult>>) byteBuffer -> {
-                        if (terminated.get()) {
-                            return Mono.empty();
-                        }
-                        byte[] byteArray = new byte[byteBuffer.remaining()];
-                        if (byteBuffer.hasArray()) {
-                            System.arraycopy(byteBuffer.array(), byteBuffer.position(), byteArray, 0, byteBuffer.remaining());
-                        } else {
-                            byteBuffer.mark();
-                            byteBuffer.get(byteArray);
-                            byteBuffer.reset();
-                        }
-                        Binary data = new Binary(byteArray);
-                        lengthInBytes.addAndGet(data.length());
+        return new ResizingByteBufferFlux(source, chunkSizeBytes)
+                .takeUntilOther(createMonoTimer(timeout))
+                .index()
+                .flatMap(indexAndBuffer -> {
+                    if (terminated.get()) {
+                        return Mono.empty();
+                    }
+                    Long index = indexAndBuffer.getT1();
+                    ByteBuffer byteBuffer = indexAndBuffer.getT2();
+                    byte[] byteArray = new byte[byteBuffer.remaining()];
+                    if (byteBuffer.hasArray()) {
+                        System.arraycopy(byteBuffer.array(), byteBuffer.position(), byteArray, 0, byteBuffer.remaining());
+                    } else {
+                        byteBuffer.mark();
+                        byteBuffer.get(byteArray);
+                        byteBuffer.reset();
+                    }
+                    Binary data = new Binary(byteArray);

-                        Document chunkDocument = new Document("files_id", fileId)
-                                .append("n", chunkIndex.getAndIncrement())
-                                .append("data", data);
+                    Document chunkDocument = new Document("files_id", fileId)
+                            .append("n", index.intValue())
+                            .append("data", data);

-                        if (clientSession == null) {
-                            return collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE).insertOne(chunkDocument);
-                        } else {
-                            return collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE).insertOne(clientSession,
-                                    chunkDocument);
-                        }
+                    Publisher<InsertOneResult> insertOnePublisher = clientSession == null
+                            ? collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE).insertOne(chunkDocument)
+                            : collectionWithTimeout(chunksCollection, timeout, TIMEOUT_ERROR_MESSAGE)
+                                    .insertOne(clientSession, chunkDocument);

-                    })
-                    .subscribe(null, sink::error, () -> sink.success(lengthInBytes.get()));
-        });
+                    return Mono.from(insertOnePublisher).thenReturn(data.length());
+                })
+                .reduce(0L, Long::sum);
     }

     /**
```
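A side note on the `createSaveChunksMono` rewrite above: it replaces the `AtomicInteger`/`AtomicLong` counters mutated inside `flatMap` with `Flux.index()` for chunk numbering and `reduce` for length accumulation. A toy sketch of that pattern (toy data, not driver code):

```java
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public final class IndexReduceSketch {
    public static void main(String[] args) {
        Long total = Flux.just("ab", "cde", "f")       // stands in for the chunk buffers
                .index()                               // -> (0,"ab"), (1,"cde"), (2,"f")
                .flatMap(t -> {
                    long chunkNumber = t.getT1();      // replaces chunkIndex.getAndIncrement()
                    int chunkLength = t.getT2().length();
                    System.out.println("chunk " + chunkNumber + ": " + chunkLength + " bytes");
                    return Mono.just(chunkLength);     // per-chunk size, like thenReturn(data.length())
                })
                .reduce(0L, Long::sum)                 // replaces the lengthInBytes accumulator
                .block();
        System.out.println("total bytes: " + total);   // total bytes: 6
    }
}
```

Besides removing the mutable state, this lets the whole method return a composed `Mono<Long>` instead of bridging through `Mono.create`, which is what made the context propagation automatic here.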