Skip to content

Commit fbf633f

Browse files
fix: make complement() use forEachUnfiltered when used in a precomputed stream
1 parent f126f08 commit fbf633f

File tree

9 files changed

+155
-12
lines changed

9 files changed

+155
-12
lines changed

core/src/main/java/ai/timefold/solver/core/api/score/stream/bi/BiConstraintStream.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import ai.timefold.solver.core.api.score.stream.tri.TriConstraintStream;
3434
import ai.timefold.solver.core.api.score.stream.tri.TriJoiner;
3535
import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream;
36+
import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStream;
3637
import ai.timefold.solver.core.impl.util.ConstantLambdaUtils;
3738

3839
import org.jspecify.annotations.NonNull;
@@ -1582,9 +1583,19 @@ <C> TriConstraintStream<A, B, C> concat(@NonNull TriConstraintStream<A, B, C> ot
15821583
@NonNull Function<A, B> paddingFunction) {
15831584
var firstStream = this;
15841585
var remapped = firstStream.map(ConstantLambdaUtils.biPickFirst());
1585-
var secondStream = getConstraintFactory().forEach(otherClass)
1586-
.ifNotExists(remapped, Joiners.equal());
1587-
return firstStream.concat(secondStream, paddingFunction);
1586+
1587+
if (firstStream instanceof AbstractConstraintStream<?> abstractConstraintStream) {
1588+
var secondStream = switch (abstractConstraintStream.getRetrievalSemantics()) {
1589+
case STANDARD, LEGACY -> getConstraintFactory().forEach(otherClass);
1590+
case STATIC -> getConstraintFactory().forEachUnfiltered(otherClass);
1591+
};
1592+
return firstStream.concat(secondStream.ifNotExists(remapped, Joiners.equal()),
1593+
paddingFunction);
1594+
} else {
1595+
var secondStream = getConstraintFactory().forEach(otherClass)
1596+
.ifNotExists(remapped, Joiners.equal());
1597+
return firstStream.concat(secondStream, paddingFunction);
1598+
}
15881599
}
15891600

15901601
// ************************************************************************

core/src/main/java/ai/timefold/solver/core/api/score/stream/quad/QuadConstraintStream.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import ai.timefold.solver.core.api.score.stream.penta.PentaJoiner;
3535
import ai.timefold.solver.core.api.score.stream.tri.TriConstraintStream;
3636
import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream;
37+
import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStream;
3738
import ai.timefold.solver.core.impl.util.ConstantLambdaUtils;
3839

3940
import org.jspecify.annotations.NonNull;
@@ -1083,9 +1084,19 @@ QuadConstraintStream<A, B, C, D> concat(@NonNull TriConstraintStream<A, B, C> ot
10831084
@NonNull Function<A, D> paddingFunctionD) {
10841085
var firstStream = this;
10851086
var remapped = firstStream.map(ConstantLambdaUtils.quadPickFirst());
1086-
var secondStream = getConstraintFactory().forEach(otherClass)
1087-
.ifNotExists(remapped, Joiners.equal());
1088-
return firstStream.concat(secondStream, paddingFunctionB, paddingFunctionC, paddingFunctionD);
1087+
1088+
if (firstStream instanceof AbstractConstraintStream<?> abstractConstraintStream) {
1089+
var secondStream = switch (abstractConstraintStream.getRetrievalSemantics()) {
1090+
case STANDARD, LEGACY -> getConstraintFactory().forEach(otherClass);
1091+
case STATIC -> getConstraintFactory().forEachUnfiltered(otherClass);
1092+
};
1093+
return firstStream.concat(secondStream.ifNotExists(remapped, Joiners.equal()),
1094+
paddingFunctionB, paddingFunctionC, paddingFunctionD);
1095+
} else {
1096+
var secondStream = getConstraintFactory().forEach(otherClass)
1097+
.ifNotExists(remapped, Joiners.equal());
1098+
return firstStream.concat(secondStream, paddingFunctionB, paddingFunctionC, paddingFunctionD);
1099+
}
10891100
}
10901101

10911102
// ************************************************************************

core/src/main/java/ai/timefold/solver/core/api/score/stream/tri/TriConstraintStream.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import ai.timefold.solver.core.api.score.stream.quad.QuadConstraintStream;
3535
import ai.timefold.solver.core.api.score.stream.quad.QuadJoiner;
3636
import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream;
37+
import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStream;
3738
import ai.timefold.solver.core.impl.util.ConstantLambdaUtils;
3839

3940
import org.jspecify.annotations.NonNull;
@@ -1294,9 +1295,19 @@ TriConstraintStream<A, B, C> concat(@NonNull BiConstraintStream<A, B> otherStrea
12941295
@NonNull Function<A, B> paddingFunctionB, @NonNull Function<A, C> paddingFunctionC) {
12951296
var firstStream = this;
12961297
var remapped = firstStream.map(ConstantLambdaUtils.triPickFirst());
1297-
var secondStream = getConstraintFactory().forEach(otherClass)
1298-
.ifNotExists(remapped, Joiners.equal());
1299-
return firstStream.concat(secondStream, paddingFunctionB, paddingFunctionC);
1298+
1299+
if (firstStream instanceof AbstractConstraintStream<?> abstractConstraintStream) {
1300+
var secondStream = switch (abstractConstraintStream.getRetrievalSemantics()) {
1301+
case STANDARD, LEGACY -> getConstraintFactory().forEach(otherClass);
1302+
case STATIC -> getConstraintFactory().forEachUnfiltered(otherClass);
1303+
};
1304+
return firstStream.concat(secondStream.ifNotExists(remapped, Joiners.equal()),
1305+
paddingFunctionB, paddingFunctionC);
1306+
} else {
1307+
var secondStream = getConstraintFactory().forEach(otherClass)
1308+
.ifNotExists(remapped, Joiners.equal());
1309+
return firstStream.concat(secondStream, paddingFunctionB, paddingFunctionC);
1310+
}
13001311
}
13011312

13021313
// ************************************************************************

core/src/main/java/ai/timefold/solver/core/api/score/stream/uni/UniConstraintStream.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import ai.timefold.solver.core.api.score.stream.bi.BiJoiner;
3434
import ai.timefold.solver.core.api.score.stream.quad.QuadConstraintStream;
3535
import ai.timefold.solver.core.api.score.stream.tri.TriConstraintStream;
36+
import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStream;
3637

3738
import org.jspecify.annotations.NonNull;
3839

@@ -1727,9 +1728,17 @@ default UniConstraintStream<A> ifNotExistsOtherIncludingNullVars(Class<A> otherC
17271728
*/
17281729
default @NonNull UniConstraintStream<A> complement(@NonNull Class<A> otherClass) {
17291730
var firstStream = this;
1730-
var secondStream = getConstraintFactory().forEach(otherClass)
1731-
.ifNotExists(firstStream, Joiners.equal());
1732-
return firstStream.concat(secondStream);
1731+
if (firstStream instanceof AbstractConstraintStream<?> abstractConstraintStream) {
1732+
var secondStream = switch (abstractConstraintStream.getRetrievalSemantics()) {
1733+
case STANDARD, LEGACY -> getConstraintFactory().forEach(otherClass);
1734+
case STATIC -> getConstraintFactory().forEachUnfiltered(otherClass);
1735+
};
1736+
return firstStream.concat(secondStream.ifNotExists(firstStream, Joiners.equal()));
1737+
} else {
1738+
var secondStream = getConstraintFactory().forEach(otherClass)
1739+
.ifNotExists(firstStream, Joiners.equal());
1740+
return firstStream.concat(secondStream);
1741+
}
17331742
}
17341743

17351744
// ************************************************************************

core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/ConstraintStreamFunctionalTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,8 @@ default void precompute_filter_3_changed() {
202202

203203
void precompute_distinct();
204204

205+
void precompute_complement();
206+
205207
void penalizeUnweighted();
206208

207209
void penalizeUnweightedLong();

core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamTest.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3651,4 +3651,29 @@ public void precompute_distinct() {
36513651
(entity, joinedValue) -> joinedValue)
36523652
.distinct());
36533653
}
3654+
3655+
@Override
3656+
@TestTemplate
3657+
public void precompute_complement() {
3658+
var solution = TestdataLavishSolution.generateEmptySolution();
3659+
var entityWithoutGroup = new TestdataLavishEntity();
3660+
var entityWithGroup1 = new TestdataLavishEntity();
3661+
var entityWithGroup2 = new TestdataLavishEntity();
3662+
var entityGroup = new TestdataLavishEntityGroup();
3663+
entityWithGroup1.setEntityGroup(entityGroup);
3664+
entityWithGroup2.setEntityGroup(entityGroup);
3665+
solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2));
3666+
solution.getEntityGroupList().add(entityGroup);
3667+
var value = new TestdataLavishValue();
3668+
solution.getValueList().add(value);
3669+
3670+
assertPrecompute(solution, List.of(
3671+
new Pair<>(entityWithGroup1, value),
3672+
new Pair<>(entityWithGroup2, value),
3673+
new Pair<>(entityWithoutGroup, null)),
3674+
pf -> pf.forEachUnfiltered(TestdataLavishEntity.class)
3675+
.join(TestdataLavishValue.class)
3676+
.filter((entity, joinedValue) -> entity.getEntityGroup() != null)
3677+
.complement(TestdataLavishEntity.class));
3678+
}
36543679
}

core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/quad/AbstractQuadConstraintStreamTest.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2394,6 +2394,33 @@ public void precompute_distinct() {
23942394
.distinct());
23952395
}
23962396

2397+
@Override
2398+
@TestTemplate
2399+
public void precompute_complement() {
2400+
var solution = TestdataLavishSolution.generateEmptySolution();
2401+
var entityWithoutGroup = new TestdataLavishEntity();
2402+
var entityWithGroup1 = new TestdataLavishEntity();
2403+
var entityWithGroup2 = new TestdataLavishEntity();
2404+
var entityGroup = new TestdataLavishEntityGroup();
2405+
entityWithGroup1.setEntityGroup(entityGroup);
2406+
entityWithGroup2.setEntityGroup(entityGroup);
2407+
solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2));
2408+
solution.getEntityGroupList().add(entityGroup);
2409+
var value = new TestdataLavishValue();
2410+
solution.getValueList().add(value);
2411+
2412+
assertPrecompute(solution, List.of(
2413+
new Quadruple<>(entityWithGroup1, value, value, value),
2414+
new Quadruple<>(entityWithGroup2, value, value, value),
2415+
new Quadruple<>(entityWithoutGroup, null, null, null)),
2416+
pf -> pf.forEachUnfiltered(TestdataLavishEntity.class)
2417+
.join(TestdataLavishValue.class)
2418+
.join(TestdataLavishValue.class)
2419+
.join(TestdataLavishValue.class)
2420+
.filter((entity, joinedValue1, joinedValue2, joinedValue3) -> entity.getEntityGroup() != null)
2421+
.complement(TestdataLavishEntity.class));
2422+
}
2423+
23972424
@Override
23982425
@TestTemplate
23992426
public void penalizeUnweighted() {

core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/tri/AbstractTriConstraintStreamTest.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2648,6 +2648,33 @@ public void precompute_distinct() {
26482648
.distinct());
26492649
}
26502650

2651+
@Override
2652+
@TestTemplate
2653+
public void precompute_complement() {
2654+
var solution = TestdataLavishSolution.generateEmptySolution();
2655+
var entityWithoutGroup = new TestdataLavishEntity();
2656+
var entityWithGroup1 = new TestdataLavishEntity();
2657+
var entityWithGroup2 = new TestdataLavishEntity();
2658+
var entityGroup = new TestdataLavishEntityGroup();
2659+
entityWithGroup1.setEntityGroup(entityGroup);
2660+
entityWithGroup2.setEntityGroup(entityGroup);
2661+
solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2));
2662+
solution.getEntityGroupList().add(entityGroup);
2663+
var value = new TestdataLavishValue();
2664+
solution.getValueList().add(value);
2665+
2666+
assertPrecompute(solution, List.of(
2667+
new Triple<>(entityWithGroup1, value, value),
2668+
new Triple<>(entityWithGroup2, value, value),
2669+
new Triple<>(entityWithoutGroup, null, null)),
2670+
pf -> pf.forEachUnfiltered(TestdataLavishEntity.class)
2671+
.join(TestdataLavishValue.class)
2672+
.join(TestdataLavishValue.class)
2673+
.filter((entity, joinedValue1, joinedValue2) -> entity.getEntityGroup() != null)
2674+
.complement(TestdataLavishEntity.class)
2675+
.distinct());
2676+
}
2677+
26512678
@Override
26522679
@TestTemplate
26532680
public void penalizeUnweighted() {

core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/uni/AbstractUniConstraintStreamTest.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4039,6 +4039,26 @@ public void precompute_distinct() {
40394039
.distinct());
40404040
}
40414041

4042+
@Override
4043+
@TestTemplate
4044+
public void precompute_complement() {
4045+
var solution = TestdataLavishSolution.generateEmptySolution();
4046+
var entityWithoutGroup = new TestdataLavishEntity();
4047+
var entityWithGroup1 = new TestdataLavishEntity();
4048+
var entityWithGroup2 = new TestdataLavishEntity();
4049+
var entityGroup = new TestdataLavishEntityGroup();
4050+
entityWithGroup1.setEntityGroup(entityGroup);
4051+
entityWithGroup2.setEntityGroup(entityGroup);
4052+
solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2));
4053+
solution.getEntityGroupList().add(entityGroup);
4054+
solution.getValueList().add(new TestdataLavishValue());
4055+
4056+
assertPrecompute(solution, List.of(entityWithGroup1, entityWithGroup2, entityWithoutGroup),
4057+
pf -> pf.forEachUnfiltered(TestdataLavishEntity.class)
4058+
.filter(entity -> entity.getEntityGroup() != null)
4059+
.complement(TestdataLavishEntity.class));
4060+
}
4061+
40424062
@TestTemplate
40434063
public void constraintProvidedFromUnknownPackage() throws ClassNotFoundException, NoSuchMethodException,
40444064
InvocationTargetException, IllegalAccessException {

0 commit comments

Comments
 (0)