Skip to content

Making the Subscribers use a common base class #154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
191 changes: 104 additions & 87 deletions src/main/java/org/dataloader/DataLoaderHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,13 @@ CompletableFuture<V> load(K key, Object loadContext) {
}
}

@SuppressWarnings("unchecked")
Object getCacheKey(K key) {
return loaderOptions.cacheKeyFunction().isPresent() ?
loaderOptions.cacheKeyFunction().get().getKey(key) : key;
}

@SuppressWarnings("unchecked")
Object getCacheKeyWithContext(K key, Object context) {
return loaderOptions.cacheKeyFunction().isPresent() ?
loaderOptions.cacheKeyFunction().get().getKeyWithContext(key, context) : key;
Expand Down Expand Up @@ -511,6 +513,7 @@ private CompletableFuture<List<V>> invokeBatchPublisher(List<K> keys, List<Objec

BatchLoaderScheduler batchLoaderScheduler = loaderOptions.getBatchLoaderScheduler();
if (batchLoadFunction instanceof BatchPublisherWithContext) {
//noinspection unchecked
BatchPublisherWithContext<K, V> loadFunction = (BatchPublisherWithContext<K, V>) batchLoadFunction;
if (batchLoaderScheduler != null) {
BatchLoaderScheduler.ScheduledBatchPublisherCall loadCall = () -> loadFunction.load(keys, subscriber, environment);
Expand All @@ -519,6 +522,7 @@ private CompletableFuture<List<V>> invokeBatchPublisher(List<K> keys, List<Objec
loadFunction.load(keys, subscriber, environment);
}
} else {
//noinspection unchecked
BatchPublisher<K, V> loadFunction = (BatchPublisher<K, V>) batchLoadFunction;
if (batchLoaderScheduler != null) {
BatchLoaderScheduler.ScheduledBatchPublisherCall loadCall = () -> loadFunction.load(keys, subscriber);
Expand All @@ -536,6 +540,7 @@ private CompletableFuture<List<V>> invokeMappedBatchPublisher(List<K> keys, List

BatchLoaderScheduler batchLoaderScheduler = loaderOptions.getBatchLoaderScheduler();
if (batchLoadFunction instanceof MappedBatchPublisherWithContext) {
//noinspection unchecked
MappedBatchPublisherWithContext<K, V> loadFunction = (MappedBatchPublisherWithContext<K, V>) batchLoadFunction;
if (batchLoaderScheduler != null) {
BatchLoaderScheduler.ScheduledBatchPublisherCall loadCall = () -> loadFunction.load(keys, subscriber, environment);
Expand All @@ -544,6 +549,7 @@ private CompletableFuture<List<V>> invokeMappedBatchPublisher(List<K> keys, List
loadFunction.load(keys, subscriber, environment);
}
} else {
//noinspection unchecked
MappedBatchPublisher<K, V> loadFunction = (MappedBatchPublisher<K, V>) batchLoadFunction;
if (batchLoaderScheduler != null) {
BatchLoaderScheduler.ScheduledBatchPublisherCall loadCall = () -> loadFunction.load(keys, subscriber);
Expand Down Expand Up @@ -618,24 +624,23 @@ private static <T> DispatchResult<T> emptyDispatchResult() {
return (DispatchResult<T>) EMPTY_DISPATCH_RESULT;
}

private class DataLoaderSubscriber implements Subscriber<V> {
private abstract class DataLoaderSubscriberBase<T> implements Subscriber<T> {

private final CompletableFuture<List<V>> valuesFuture;
private final List<K> keys;
private final List<Object> callContexts;
private final List<CompletableFuture<V>> queuedFutures;
final CompletableFuture<List<V>> valuesFuture;
final List<K> keys;
final List<Object> callContexts;
final List<CompletableFuture<V>> queuedFutures;

private final List<K> clearCacheKeys = new ArrayList<>();
private final List<V> completedValues = new ArrayList<>();
private int idx = 0;
private boolean onErrorCalled = false;
private boolean onCompleteCalled = false;
List<K> clearCacheKeys = new ArrayList<>();
List<V> completedValues = new ArrayList<>();
boolean onErrorCalled = false;
boolean onCompleteCalled = false;

private DataLoaderSubscriber(
CompletableFuture<List<V>> valuesFuture,
List<K> keys,
List<Object> callContexts,
List<CompletableFuture<V>> queuedFutures
DataLoaderSubscriberBase(
CompletableFuture<List<V>> valuesFuture,
List<K> keys,
List<Object> callContexts,
List<CompletableFuture<V>> queuedFutures
) {
this.valuesFuture = valuesFuture;
this.keys = keys;
Expand All @@ -648,55 +653,97 @@ public void onSubscribe(Subscription subscription) {
subscription.request(keys.size());
}

// onNext may be called by multiple threads - for the time being, we pass 'synchronized' to guarantee
// correctness (at the cost of speed).
@Override
public synchronized void onNext(V value) {
public void onNext(T v) {
assertState(!onErrorCalled, () -> "onError has already been called; onNext may not be invoked.");
assertState(!onCompleteCalled, () -> "onComplete has already been called; onNext may not be invoked.");
}

K key = keys.get(idx);
Object callContext = callContexts.get(idx);
CompletableFuture<V> future = queuedFutures.get(idx);
@Override
public void onComplete() {
assertState(!onErrorCalled, () -> "onError has already been called; onComplete may not be invoked.");
onCompleteCalled = true;
}

@Override
public void onError(Throwable throwable) {
assertState(!onCompleteCalled, () -> "onComplete has already been called; onError may not be invoked.");
onErrorCalled = true;

stats.incrementBatchLoadExceptionCount(new IncrementBatchLoadExceptionCountStatisticsContext<>(keys, callContexts));
}

/*
* A value has arrived - how do we complete the future that's associated with it in a common way
*/
void onNextValue(K key, V value, Object callContext, List<CompletableFuture<V>> futures) {
if (value instanceof Try) {
// we allow the batch loader to return a Try so we can better represent a computation
// that might have worked or not.
//noinspection unchecked
Try<V> tryValue = (Try<V>) value;
if (tryValue.isSuccess()) {
future.complete(tryValue.get());
futures.forEach(f -> f.complete(tryValue.get()));
} else {
stats.incrementLoadErrorCount(new IncrementLoadErrorCountStatisticsContext<>(key, callContext));
future.completeExceptionally(tryValue.getThrowable());
clearCacheKeys.add(keys.get(idx));
futures.forEach(f -> f.completeExceptionally(tryValue.getThrowable()));
clearCacheKeys.add(key);
}
} else {
future.complete(value);
futures.forEach(f -> f.complete(value));
}
}

Throwable unwrapThrowable(Throwable ex) {
if (ex instanceof CompletionException) {
ex = ex.getCause();
}
return ex;
}
}

private class DataLoaderSubscriber extends DataLoaderSubscriberBase<V> {

private int idx = 0;

private DataLoaderSubscriber(
CompletableFuture<List<V>> valuesFuture,
List<K> keys,
List<Object> callContexts,
List<CompletableFuture<V>> queuedFutures
) {
super(valuesFuture, keys, callContexts, queuedFutures);
}

// onNext may be called by multiple threads - for the time being, we pass 'synchronized' to guarantee
// correctness (at the cost of speed).
@Override
public synchronized void onNext(V value) {
super.onNext(value);

K key = keys.get(idx);
Object callContext = callContexts.get(idx);
CompletableFuture<V> future = queuedFutures.get(idx);
onNextValue(key, value, callContext, List.of(future));

completedValues.add(value);
idx++;
}

@Override
public void onComplete() {
assertState(!onErrorCalled, () -> "onError has already been called; onComplete may not be invoked.");
onCompleteCalled = true;

@Override
public synchronized void onComplete() {
super.onComplete();
assertResultSize(keys, completedValues);

possiblyClearCacheEntriesOnExceptions(clearCacheKeys);
valuesFuture.complete(completedValues);
}

@Override
public void onError(Throwable ex) {
assertState(!onCompleteCalled, () -> "onComplete has already been called; onError may not be invoked.");
onErrorCalled = true;

stats.incrementBatchLoadExceptionCount(new IncrementBatchLoadExceptionCountStatisticsContext<>(keys, callContexts));
if (ex instanceof CompletionException) {
ex = ex.getCause();
}
public synchronized void onError(Throwable ex) {
super.onError(ex);
ex = unwrapThrowable(ex);
// Set the remaining keys to the exception.
for (int i = idx; i < queuedFutures.size(); i++) {
K key = keys.get(i);
Expand All @@ -705,33 +752,25 @@ public void onError(Throwable ex) {
// clear any cached view of this key because they all failed
dataLoader.clear(key);
}
valuesFuture.completeExceptionally(ex);
}

}

private class DataLoaderMapEntrySubscriber implements Subscriber<Map.Entry<K, V>> {
private final CompletableFuture<List<V>> valuesFuture;
private final List<K> keys;
private final List<Object> callContexts;
private final List<CompletableFuture<V>> queuedFutures;
private class DataLoaderMapEntrySubscriber extends DataLoaderSubscriberBase<Map.Entry<K, V>> {

private final Map<K, Object> callContextByKey;
private final Map<K, List<CompletableFuture<V>>> queuedFuturesByKey;

private final List<K> clearCacheKeys = new ArrayList<>();
private final Map<K, V> completedValuesByKey = new HashMap<>();
private boolean onErrorCalled = false;
private boolean onCompleteCalled = false;


private DataLoaderMapEntrySubscriber(
CompletableFuture<List<V>> valuesFuture,
List<K> keys,
List<Object> callContexts,
List<CompletableFuture<V>> queuedFutures
CompletableFuture<List<V>> valuesFuture,
List<K> keys,
List<Object> callContexts,
List<CompletableFuture<V>> queuedFutures
) {
this.valuesFuture = valuesFuture;
this.keys = keys;
this.callContexts = callContexts;
this.queuedFutures = queuedFutures;

super(valuesFuture, keys, callContexts, queuedFutures);
this.callContextByKey = new HashMap<>();
this.queuedFuturesByKey = new HashMap<>();
for (int idx = 0; idx < queuedFutures.size(); idx++) {
Expand All @@ -743,42 +782,24 @@ private DataLoaderMapEntrySubscriber(
}
}

@Override
public void onSubscribe(Subscription subscription) {
subscription.request(keys.size());
}

@Override
public void onNext(Map.Entry<K, V> entry) {
assertState(!onErrorCalled, () -> "onError has already been called; onNext may not be invoked.");
assertState(!onCompleteCalled, () -> "onComplete has already been called; onNext may not be invoked.");
public synchronized void onNext(Map.Entry<K, V> entry) {
super.onNext(entry);
K key = entry.getKey();
V value = entry.getValue();

Object callContext = callContextByKey.get(key);
List<CompletableFuture<V>> futures = queuedFuturesByKey.get(key);
if (value instanceof Try) {
// we allow the batch loader to return a Try so we can better represent a computation
// that might have worked or not.
Try<V> tryValue = (Try<V>) value;
if (tryValue.isSuccess()) {
futures.forEach(f -> f.complete(tryValue.get()));
} else {
stats.incrementLoadErrorCount(new IncrementLoadErrorCountStatisticsContext<>(key, callContext));
futures.forEach(f -> f.completeExceptionally(tryValue.getThrowable()));
clearCacheKeys.add(key);
}
} else {
futures.forEach(f -> f.complete(value));
}

onNextValue(key, value, callContext, futures);

completedValuesByKey.put(key, value);
}

@Override
public void onComplete() {
assertState(!onErrorCalled, () -> "onError has already been called; onComplete may not be invoked.");
onCompleteCalled = true;
public synchronized void onComplete() {
super.onComplete();

possiblyClearCacheEntriesOnExceptions(clearCacheKeys);
List<V> values = new ArrayList<>(keys.size());
Expand All @@ -790,14 +811,9 @@ public void onComplete() {
}

@Override
public void onError(Throwable ex) {
assertState(!onCompleteCalled, () -> "onComplete has already been called; onError may not be invoked.");
onErrorCalled = true;

stats.incrementBatchLoadExceptionCount(new IncrementBatchLoadExceptionCountStatisticsContext<>(keys, callContexts));
if (ex instanceof CompletionException) {
ex = ex.getCause();
}
public synchronized void onError(Throwable ex) {
super.onError(ex);
ex = unwrapThrowable(ex);
// Complete the futures for the remaining keys with the exception.
for (int idx = 0; idx < queuedFutures.size(); idx++) {
K key = keys.get(idx);
Expand All @@ -810,6 +826,7 @@ public void onError(Throwable ex) {
dataLoader.clear(key);
}
}
valuesFuture.completeExceptionally(ex);
}
}
}