@@ -714,7 +714,7 @@ private class DataLoaderMapEntrySubscriber implements Subscriber<Map.Entry<K, V>
714
714
private final List <Object > callContexts ;
715
715
private final List <CompletableFuture <V >> queuedFutures ;
716
716
private final Map <K , Object > callContextByKey ;
717
- private final Map <K , CompletableFuture <V >> queuedFutureByKey ;
717
+ private final Map <K , List < CompletableFuture <V >>> queuedFuturesByKey ;
718
718
719
719
private final List <K > clearCacheKeys = new ArrayList <>();
720
720
private final Map <K , V > completedValuesByKey = new HashMap <>();
@@ -733,13 +733,13 @@ private DataLoaderMapEntrySubscriber(
733
733
this .queuedFutures = queuedFutures ;
734
734
735
735
this .callContextByKey = new HashMap <>();
736
- this .queuedFutureByKey = new HashMap <>();
736
+ this .queuedFuturesByKey = new HashMap <>();
737
737
for (int idx = 0 ; idx < queuedFutures .size (); idx ++) {
738
738
K key = keys .get (idx );
739
739
Object callContext = callContexts .get (idx );
740
740
CompletableFuture <V > queuedFuture = queuedFutures .get (idx );
741
741
callContextByKey .put (key , callContext );
742
- queuedFutureByKey . put (key , queuedFuture );
742
+ queuedFuturesByKey . computeIfAbsent (key , k -> new ArrayList <>()). add ( queuedFuture );
743
743
}
744
744
}
745
745
@@ -756,20 +756,20 @@ public void onNext(Map.Entry<K, V> entry) {
756
756
V value = entry .getValue ();
757
757
758
758
Object callContext = callContextByKey .get (key );
759
- CompletableFuture <V > future = queuedFutureByKey .get (key );
759
+ List < CompletableFuture <V >> futures = queuedFuturesByKey .get (key );
760
760
if (value instanceof Try ) {
761
761
// we allow the batch loader to return a Try so we can better represent a computation
762
762
// that might have worked or not.
763
763
Try <V > tryValue = (Try <V >) value ;
764
764
if (tryValue .isSuccess ()) {
765
- future . complete (tryValue .get ());
765
+ futures . forEach ( f -> f . complete (tryValue .get () ));
766
766
} else {
767
767
stats .incrementLoadErrorCount (new IncrementLoadErrorCountStatisticsContext <>(key , callContext ));
768
- future . completeExceptionally (tryValue .getThrowable ());
768
+ futures . forEach ( f -> f . completeExceptionally (tryValue .getThrowable () ));
769
769
clearCacheKeys .add (key );
770
770
}
771
771
} else {
772
- future . complete (value );
772
+ futures . forEach ( f -> f . complete (value ) );
773
773
}
774
774
775
775
completedValuesByKey .put (key , value );
@@ -801,9 +801,11 @@ public void onError(Throwable ex) {
801
801
// Complete the futures for the remaining keys with the exception.
802
802
for (int idx = 0 ; idx < queuedFutures .size (); idx ++) {
803
803
K key = keys .get (idx );
804
- CompletableFuture <V > future = queuedFutureByKey .get (key );
804
+ List < CompletableFuture <V >> futures = queuedFuturesByKey .get (key );
805
805
if (!completedValuesByKey .containsKey (key )) {
806
- future .completeExceptionally (ex );
806
+ for (CompletableFuture <V > future : futures ) {
807
+ future .completeExceptionally (ex );
808
+ }
807
809
// clear any cached view of this key because they all failed
808
810
dataLoader .clear (key );
809
811
}
0 commit comments