Skip to content

Commit 11391b9

Browse files
committed
Replace reduce with individual reductions
1 parent 81305dd commit 11391b9

File tree

3 files changed

+214
-42
lines changed

3 files changed

+214
-42
lines changed

driver-core/src/main/com/mongodb/client/model/expressions/ArrayExpression.java

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
package com.mongodb.client.model.expressions;
1818

19-
import java.util.function.BinaryOperator;
2019
import java.util.function.Function;
2120

2221
import static com.mongodb.client.model.expressions.Expressions.of;
@@ -51,28 +50,48 @@ public interface ArrayExpression<T extends Expression> extends Expression {
5150
*/
5251
<R extends Expression> ArrayExpression<R> map(Function<? super T, ? extends R> in);
5352

53+
IntegerExpression size();
54+
55+
BooleanExpression any(Function<T, BooleanExpression> map);
56+
57+
BooleanExpression all(Function<T, BooleanExpression> map);
58+
59+
NumberExpression sum(Function<T, NumberExpression> map);
60+
61+
NumberExpression multiply(Function<T, NumberExpression> map);
62+
63+
NumberExpression max(Function<T, NumberExpression> map, NumberExpression orElse);
64+
65+
NumberExpression min(Function<T, NumberExpression> map, NumberExpression orElse);
66+
67+
StringExpression join(Function<T, StringExpression> map);
68+
69+
<R extends Expression> ArrayExpression<R> concat(Function<T, ArrayExpression<R>> map);
70+
71+
<R extends Expression> ArrayExpression<R> union(Function<T, ArrayExpression<R>> map);
72+
5473
/**
55-
* Performs a reduction on the elements of this array, using the provided
56-
* identity value and an associative reducing function, and returns
57-
* the reduced value. The initial value must be the identity value for the
58-
* reducing function.
74+
* user asserts that i is in bounds for the array
5975
*
60-
* @param initialValue the identity for the reducing function
61-
* @param in the associative reducing function
62-
* @return the reduced value
76+
* @param i
77+
* @return
6378
*/
64-
T reduce(T initialValue, BinaryOperator<T> in);
65-
66-
IntegerExpression size();
67-
6879
T elementAt(IntegerExpression i);
6980

7081
default T elementAt(final int i) {
7182
return this.elementAt(of(i));
7283
}
7384

85+
/**
86+
* user asserts that array is not empty
87+
* @return
88+
*/
7489
T first();
7590

91+
/**
92+
* user asserts that array is not empty
93+
* @return
94+
*/
7695
T last();
7796

7897
BooleanExpression contains(T contains);

driver-core/src/main/com/mongodb/client/model/expressions/MqlExpression.java

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import java.util.function.Function;
2828

2929
import static com.mongodb.client.model.expressions.Expressions.of;
30+
import static com.mongodb.client.model.expressions.Expressions.ofArray;
31+
import static com.mongodb.client.model.expressions.Expressions.ofNull;
3032
import static com.mongodb.client.model.expressions.Expressions.ofStringArray;
3133

3234
final class MqlExpression<T extends Expression>
@@ -385,7 +387,6 @@ public ArrayExpression<T> filter(final Function<? super T, ? extends BooleanExpr
385387
.append("cond", extractBsonValue(cr, cond.apply(varThis)))));
386388
}
387389

388-
@Override
389390
public T reduce(final T initialValue, final BinaryOperator<T> in) {
390391
T varThis = variable("$$this");
391392
T varValue = variable("$$value");
@@ -395,6 +396,66 @@ public T reduce(final T initialValue, final BinaryOperator<T> in) {
395396
.append("in", extractBsonValue(cr, in.apply(varValue, varThis)))));
396397
}
397398

399+
private <R extends Expression> R reduceMap(
400+
final Function<T, R> map,
401+
final R initialValue,
402+
final BinaryOperator<R> in) {
403+
MqlExpression<R> map1 = (MqlExpression<R>) this.map(map);
404+
return map1.reduce(initialValue, in);
405+
}
406+
407+
@Override
408+
public BooleanExpression any(final Function<T, BooleanExpression> map) {
409+
return reduceMap(map, of(false), (a, b) -> a.or(b));
410+
}
411+
412+
@Override
413+
public BooleanExpression all(final Function<T, BooleanExpression> map) {
414+
return reduceMap(map, of(true), (a, b) -> a.and(b));
415+
}
416+
417+
@Override
418+
public NumberExpression sum(final Function<T, NumberExpression> map) {
419+
// no sum for IntegerExpression, both have same erasure
420+
return reduceMap(map, of(0), (a, b) -> a.add(b));
421+
}
422+
423+
@Override
424+
public NumberExpression multiply(final Function<T, NumberExpression> map) {
425+
return reduceMap(map, of(0), (a, b) -> a.multiply(b));
426+
}
427+
428+
@Override
429+
public NumberExpression max(final Function<T, NumberExpression> map, final NumberExpression orElse) {
430+
return reduceMap(map,
431+
(NumberExpression) ofNull(),
432+
(a, b) -> a.max(b))
433+
.isNumberOr(orElse);
434+
}
435+
436+
@Override
437+
public NumberExpression min(final Function<T, NumberExpression> map, final NumberExpression orElse) {
438+
return reduceMap(map,
439+
(NumberExpression) ofNull(),
440+
(a, b) -> a.min(b))
441+
.isNumberOr(orElse);
442+
}
443+
444+
@Override
445+
public StringExpression join(final Function<T, StringExpression> map) {
446+
return reduceMap(map, of(""), (a, b) -> a.concat(b));
447+
}
448+
449+
@Override
450+
public <R extends Expression> ArrayExpression<R> concat(final Function<T, ArrayExpression<R>> map) {
451+
return reduceMap(map, ofArray(), (a, b) -> a.concat(b));
452+
}
453+
454+
@Override
455+
public <R extends Expression> ArrayExpression<R> union(final Function<T, ArrayExpression<R>> map) {
456+
return reduceMap(map, ofArray(), (a, b) -> a.union(b));
457+
}
458+
398459
@Override
399460
public IntegerExpression size() {
400461
return new MqlExpression<>(astWrapped("$size"));

driver-core/src/test/functional/com/mongodb/client/model/expressions/ArrayExpressionsFunctionalTest.java

Lines changed: 121 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -127,41 +127,128 @@ public void mapTest() {
127127
"{'$map': {'input': [true, true, false], 'in': {'$not': '$$this'}}}");
128128
}
129129

130+
// https://www.mongodb.com/docs/manual/reference/operator/aggregation/reduce/
131+
// reduce is implemented as each individual type of reduction (monoid)
132+
// this prevents issues related to incorrect specification of identity values
133+
130134
@Test
131-
public void reduceTest() {
132-
// https://www.mongodb.com/docs/manual/reference/operator/aggregation/reduce/
135+
public void reduceAnyTest() {
133136
assertExpression(
134-
Stream.of(true, true, false)
135-
.reduce(false, (a, b) -> a || b),
136-
arrayTTF.reduce(of(false), (a, b) -> a.or(b)),
137-
// MQL:
138-
"{'$reduce': {'input': [true, true, false], 'initialValue': false, 'in': {'$or': ['$$value', '$$this']}}}");
137+
true,
138+
arrayTTF.any(a -> a),
139+
"{'$reduce': {'input': {'$map': {'input': [true, true, false], 'in': '$$this'}}, "
140+
+ "'initialValue': false, 'in': {'$or': ['$$value', '$$this']}}}");
139141
assertExpression(
140-
Stream.of(true, true, false)
141-
.reduce(true, (a, b) -> a && b),
142-
arrayTTF.reduce(of(true), (a, b) -> a.and(b)),
143-
// MQL:
144-
"{'$reduce': {'input': [true, true, false], 'initialValue': true, 'in': {'$and': ['$$value', '$$this']}}}");
145-
// empty array
142+
false,
143+
ofBooleanArray().any(a -> a));
144+
146145
assertExpression(
147-
Stream.<Boolean>empty().reduce(true, (a, b) -> a && b),
148-
ofBooleanArray().reduce(of(true), (a, b) -> a.and(b)),
149-
// MQL:
150-
"{'$reduce': {'input': [], 'initialValue': true, 'in': {'$and': ['$$value', '$$this']}}}");
151-
// constant result
146+
true,
147+
ofIntegerArray(1, 2, 3).any(a -> a.eq(of(3))));
152148
assertExpression(
153-
Stream.of(true, true, false)
154-
.reduce(true, (a, b) -> true),
155-
arrayTTF.reduce(of(true), (a, b) -> of(true)),
156-
// MQL:
157-
"{'$reduce': {'input': [true, true, false], 'initialValue': true, 'in': true}}");
158-
// non-commutative
149+
false,
150+
ofIntegerArray(1, 2, 2).any(a -> a.eq(of(9))));
151+
}
152+
153+
@Test
154+
public void reduceAllTest() {
155+
assertExpression(
156+
false,
157+
arrayTTF.all(a -> a),
158+
"{'$reduce': {'input': {'$map': {'input': [true, true, false], 'in': '$$this'}}, "
159+
+ "'initialValue': true, 'in': {'$and': ['$$value', '$$this']}}}");
160+
assertExpression(
161+
true,
162+
ofBooleanArray().all(a -> a));
163+
164+
assertExpression(
165+
true,
166+
ofIntegerArray(1, 2, 3).all(a -> a.gt(of(0))));
167+
assertExpression(
168+
false,
169+
ofIntegerArray(1, 2, 2).all(a -> a.eq(of(2))));
170+
}
171+
172+
@Test
173+
public void reduceSumTest() {
174+
assertExpression(
175+
6,
176+
ofIntegerArray(1, 2, 3).sum(a -> a),
177+
"{'$reduce': {'input': {'$map': {'input': [1, 2, 3], 'in': '$$this'}}, "
178+
+ "'initialValue': 0, 'in': {'$add': ['$$value', '$$this']}}}");
179+
// empty array:
180+
assertExpression(
181+
0,
182+
ofIntegerArray().sum(a -> a));
183+
}
184+
185+
@Test
186+
public void reduceMaxTest() {
187+
assertExpression(
188+
3,
189+
ofIntegerArray(1, 2, 3).max(a -> a, of(9)),
190+
"{'$cond': [{'$isNumber': [{'$reduce': {'input': "
191+
+ "{'$map': {'input': [1, 2, 3], 'in': '$$this'}}, "
192+
+ "'initialValue': null, 'in': {'$max': ['$$value', '$$this']}}}]}, "
193+
+ "{'$reduce': {'input': {'$map': {'input': [1, 2, 3], 'in': '$$this'}}, "
194+
+ "'initialValue': null, 'in': {'$max': ['$$value', '$$this']}}}, 9]}");
195+
assertExpression(
196+
9,
197+
ofIntegerArray().max(a -> a, of(9)));
198+
}
199+
200+
@Test
201+
public void reduceMinTest() {
202+
assertExpression(
203+
1,
204+
ofIntegerArray(1, 2, 3).min(a -> a, of(9)),
205+
"{'$cond': [{'$isNumber': [{'$reduce': {'input': "
206+
+ "{'$map': {'input': [1, 2, 3], 'in': '$$this'}}, "
207+
+ "'initialValue': null, 'in': {'$min': ['$$value', '$$this']}}}]}, "
208+
+ "{'$reduce': {'input': {'$map': {'input': [1, 2, 3], 'in': '$$this'}}, "
209+
+ "'initialValue': null, 'in': {'$min': ['$$value', '$$this']}}}, 9]}");
210+
assertExpression(
211+
9,
212+
ofIntegerArray().min(a -> a, of(9)));
213+
}
214+
215+
@Test
216+
public void reduceJoinTest() {
159217
assertExpression(
160218
"abc",
161-
ofStringArray("a", "b", "c").reduce(of(""), (a, b) -> a.concat(b)),
162-
// MQL:
163-
"{'$reduce': {'input': ['a', 'b', 'c'], 'initialValue': '', 'in': {'$concat': ['$$value', '$$this']}}}");
219+
ofStringArray("a", "b", "c").join(a -> a),
220+
"{'$reduce': {'input': {'$map': {'input': ['a', 'b', 'c'], 'in': '$$this'}}, "
221+
+ "'initialValue': '', 'in': {'$concat': ['$$value', '$$this']}}}");
222+
assertExpression(
223+
"",
224+
ofStringArray().join(a -> a));
225+
}
164226

227+
@Test
228+
public void reduceConcatTest() {
229+
assertExpression(
230+
Arrays.asList(1, 2, 3, 4),
231+
ofArray(ofIntegerArray(1, 2), ofIntegerArray(3, 4)).concat(v -> v),
232+
"{'$reduce': {'input': {'$map': {'input': [[1, 2], [3, 4]], 'in': '$$this'}}, "
233+
+ "'initialValue': [], "
234+
+ "'in': {'$concatArrays': ['$$value', '$$this']}}} ");
235+
// empty:
236+
ArrayExpression<ArrayExpression<Expression>> expressionArrayExpression = ofArray();
237+
assertExpression(
238+
Collections.emptyList(),
239+
expressionArrayExpression.concat(a -> a));
240+
}
241+
242+
@Test
243+
public void reduceUnionTest() {
244+
// https://www.mongodb.com/docs/manual/reference/operator/aggregation/setUnion/ (40)
245+
assertExpression(
246+
Arrays.asList(1, 2, 3),
247+
ofArray(ofIntegerArray(1, 2), ofIntegerArray(1, 3)).union(v -> v),
248+
// MQL:
249+
"{'$reduce': {'input': {'$map': {'input': [[1, 2], [1, 3]], 'in': '$$this'}}, "
250+
+ "'initialValue': [], "
251+
+ "'in': {'$setUnion': ['$$value', '$$this']}}}");
165252
}
166253

167254
@Test
@@ -221,6 +308,12 @@ public void firstTest() {
221308
array123.first(),
222309
// MQL:
223310
"{'$first': [[1, 2, 3]]}");
311+
312+
assertExpression(
313+
MISSING,
314+
ofIntegerArray().first(),
315+
// MQL:
316+
"{'$first': [[]]}");
224317
}
225318

226319
@Test
@@ -289,7 +382,7 @@ public void sliceTest() {
289382
}
290383

291384
@Test
292-
public void setUnionTest() {
385+
public void unionTest() {
293386
// https://www.mongodb.com/docs/manual/reference/operator/aggregation/setUnion/
294387
assertExpression(
295388
Arrays.asList(1, 2, 3),
@@ -309,5 +402,4 @@ public void setUnionTest() {
309402
// MQL:
310403
"{'$setUnion': [[1, 2, 1, 3, 3]]}");
311404
}
312-
313405
}

0 commit comments

Comments
 (0)