Skip to content

Commit f126f08

Browse files
test: add tests for various precompute inputs
1 parent ab73433 commit f126f08

File tree

8 files changed

+1122
-154
lines changed

8 files changed

+1122
-154
lines changed

core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetPrecomputeQuadConstraintStream.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import java.util.Set;
44

55
import ai.timefold.solver.core.api.score.Score;
6-
import ai.timefold.solver.core.impl.bavet.bi.PrecomputeBiNode;
76
import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream;
87
import ai.timefold.solver.core.impl.bavet.common.TupleSource;
9-
import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple;
8+
import ai.timefold.solver.core.impl.bavet.common.tuple.QuadTuple;
9+
import ai.timefold.solver.core.impl.bavet.quad.PrecomputeQuadNode;
1010
import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory;
1111
import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetPrecomputeBuildHelper;
1212
import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper;
@@ -34,10 +34,10 @@ public void setAftBridge(BavetAftBridgeQuadConstraintStream<Solution_, A, B, C,
3434

3535
@Override
3636
public <Score_ extends Score<Score_>> void buildNode(ConstraintNodeBuildHelper<Solution_, Score_> buildHelper) {
37-
var staticDataBuildHelper = new BavetPrecomputeBuildHelper<BiTuple<A, B>>(recordingStaticConstraintStream);
37+
var staticDataBuildHelper = new BavetPrecomputeBuildHelper<QuadTuple<A, B, C, D>>(recordingStaticConstraintStream);
3838
var outputStoreSize = buildHelper.extractTupleStoreSize(aftStream);
3939

40-
buildHelper.addNode(new PrecomputeBiNode<>(staticDataBuildHelper.getNodeNetwork(),
40+
buildHelper.addNode(new PrecomputeQuadNode<>(staticDataBuildHelper.getNodeNetwork(),
4141
staticDataBuildHelper.getRecordingTupleLifecycle(),
4242
outputStoreSize,
4343
buildHelper.getAggregatedTupleLifecycle(aftStream.getChildStreamList()),

core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetPrecomputeTriConstraintStream.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import java.util.Set;
44

55
import ai.timefold.solver.core.api.score.Score;
6-
import ai.timefold.solver.core.impl.bavet.bi.PrecomputeBiNode;
76
import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream;
87
import ai.timefold.solver.core.impl.bavet.common.TupleSource;
9-
import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple;
8+
import ai.timefold.solver.core.impl.bavet.common.tuple.TriTuple;
9+
import ai.timefold.solver.core.impl.bavet.tri.PrecomputeTriNode;
1010
import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory;
1111
import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetPrecomputeBuildHelper;
1212
import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper;
@@ -33,10 +33,10 @@ public void setAftBridge(BavetAftBridgeTriConstraintStream<Solution_, A, B, C> a
3333

3434
@Override
3535
public <Score_ extends Score<Score_>> void buildNode(ConstraintNodeBuildHelper<Solution_, Score_> buildHelper) {
36-
var staticDataBuildHelper = new BavetPrecomputeBuildHelper<BiTuple<A, B>>(recordingStaticConstraintStream);
36+
var staticDataBuildHelper = new BavetPrecomputeBuildHelper<TriTuple<A, B, C>>(recordingStaticConstraintStream);
3737
var outputStoreSize = buildHelper.extractTupleStoreSize(aftStream);
3838

39-
buildHelper.addNode(new PrecomputeBiNode<>(staticDataBuildHelper.getNodeNetwork(),
39+
buildHelper.addNode(new PrecomputeTriNode<>(staticDataBuildHelper.getNodeNetwork(),
4040
staticDataBuildHelper.getRecordingTupleLifecycle(),
4141
outputStoreSize,
4242
buildHelper.getAggregatedTupleLifecycle(aftStream.getChildStreamList()),

core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/ConstraintStreamFunctionalTest.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,34 @@ default void expandToQuad() {
174174

175175
void complement();
176176

177+
void precompute_filter_0_changed();
178+
179+
default void precompute_filter_1_changed() {
180+
// requires two elements, so Bi, Tri and Quad
181+
}
182+
183+
default void precompute_filter_2_changed() {
184+
// requires three elements, so Tri and Quad
185+
}
186+
187+
default void precompute_filter_3_changed() {
188+
// requires four elements, Quad
189+
}
190+
191+
void precompute_ifExists();
192+
193+
void precompute_ifNotExists();
194+
195+
void precompute_groupBy();
196+
197+
void precompute_flattenLast();
198+
199+
void precompute_map();
200+
201+
void precompute_concat();
202+
203+
void precompute_distinct();
204+
177205
void penalizeUnweighted();
178206

179207
void penalizeUnweightedLong();

core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamTest.java

Lines changed: 265 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,14 @@
3434
import ai.timefold.solver.core.api.score.stream.ConstraintCollectors;
3535
import ai.timefold.solver.core.api.score.stream.ConstraintJustification;
3636
import ai.timefold.solver.core.api.score.stream.DefaultConstraintJustification;
37+
import ai.timefold.solver.core.api.score.stream.Joiners;
38+
import ai.timefold.solver.core.api.score.stream.PrecomputeFactory;
39+
import ai.timefold.solver.core.api.score.stream.bi.BiConstraintStream;
3740
import ai.timefold.solver.core.impl.score.director.InnerScoreDirector;
3841
import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStreamTest;
3942
import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamFunctionalTest;
4043
import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamImplSupport;
44+
import ai.timefold.solver.core.impl.util.Pair;
4145
import ai.timefold.solver.core.testdomain.TestdataEntity;
4246
import ai.timefold.solver.core.testdomain.list.unassignedvar.TestdataAllowsUnassignedValuesListEntity;
4347
import ai.timefold.solver.core.testdomain.list.unassignedvar.TestdataAllowsUnassignedValuesListSolution;
@@ -3307,8 +3311,91 @@ public void joinerEqualsAndSameness() {
33073311
assertMatch(entity3, entity2));
33083312
}
33093313

3314+
@Override
3315+
@TestTemplate
3316+
public void precompute_filter_0_changed() {
3317+
var solution = TestdataLavishSolution.generateSolution();
3318+
var entityGroup = new TestdataLavishEntityGroup("MyEntityGroup");
3319+
var valueGroup = new TestdataLavishValueGroup("MyValueGroup");
3320+
solution.getEntityGroupList().add(entityGroup);
3321+
solution.getValueGroupList().add(valueGroup);
3322+
3323+
var value1 = Mockito.spy(new TestdataLavishValue("MyValue 1", valueGroup));
3324+
solution.getValueList().add(value1);
3325+
var value2 = Mockito.spy(new TestdataLavishValue("MyValue 2", valueGroup));
3326+
solution.getValueList().add(value2);
3327+
var value3 = Mockito.spy(new TestdataLavishValue("MyValue 3", null));
3328+
solution.getValueList().add(value3);
3329+
3330+
var entity1 = Mockito.spy(new TestdataLavishEntity("MyEntity 1", entityGroup, value1));
3331+
solution.getEntityList().add(entity1);
3332+
var entity2 = new TestdataLavishEntity("MyEntity 2", entityGroup, value1);
3333+
solution.getEntityList().add(entity2);
3334+
var entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(),
3335+
value1);
3336+
solution.getEntityList().add(entity3);
3337+
3338+
var scoreDirector =
3339+
buildScoreDirector(factory -> factory.precompute(data -> data.forEachUnfiltered(TestdataLavishEntity.class)
3340+
.join(TestdataLavishValue.class)
3341+
.filter((entity, value) -> entity.getEntityGroup() == entityGroup
3342+
&& value.getValueGroup() == valueGroup))
3343+
.filter((entity, value) -> entity.getValue() == value1)
3344+
.penalize(SimpleScore.ONE)
3345+
.asConstraint(TEST_CONSTRAINT_NAME));
3346+
3347+
// From scratch
3348+
Mockito.reset(entity1);
3349+
scoreDirector.setWorkingSolution(solution);
3350+
assertScore(scoreDirector,
3351+
assertMatch(entity1, value1),
3352+
assertMatch(entity1, value2),
3353+
assertMatch(entity2, value1),
3354+
assertMatch(entity2, value2));
3355+
Mockito.verify(entity1, Mockito.atLeastOnce()).getEntityGroup();
3356+
3357+
// Incrementally update a variable
3358+
Mockito.reset(entity1);
3359+
scoreDirector.beforeVariableChanged(entity1, "value");
3360+
entity1.setValue(solution.getFirstValue());
3361+
scoreDirector.afterVariableChanged(entity1, "value");
3362+
assertScore(scoreDirector,
3363+
assertMatch(entity2, value1),
3364+
assertMatch(entity2, value2));
3365+
Mockito.verify(entity1, Mockito.never()).getEntityGroup();
3366+
3367+
// Incrementally update a fact
3368+
scoreDirector.beforeProblemPropertyChanged(entity3);
3369+
entity3.setEntityGroup(entityGroup);
3370+
scoreDirector.afterProblemPropertyChanged(entity3);
3371+
assertScore(scoreDirector,
3372+
assertMatch(entity2, value1),
3373+
assertMatch(entity2, value2),
3374+
assertMatch(entity3, value1),
3375+
assertMatch(entity3, value2));
3376+
3377+
// Remove entity
3378+
scoreDirector.beforeEntityRemoved(entity3);
3379+
solution.getEntityList().remove(entity3);
3380+
scoreDirector.afterEntityRemoved(entity3);
3381+
assertScore(scoreDirector,
3382+
assertMatch(entity2, value1),
3383+
assertMatch(entity2, value2));
3384+
3385+
// Add it back again, to make sure it was properly removed before
3386+
scoreDirector.beforeEntityAdded(entity3);
3387+
solution.getEntityList().add(entity3);
3388+
scoreDirector.afterEntityAdded(entity3);
3389+
assertScore(scoreDirector,
3390+
assertMatch(entity2, value1),
3391+
assertMatch(entity2, value2),
3392+
assertMatch(entity3, value1),
3393+
assertMatch(entity3, value2));
3394+
}
3395+
3396+
@Override
33103397
@TestTemplate
3311-
public void precompute_join_filter_map_entity_right() {
3398+
public void precompute_filter_1_changed() {
33123399
var solution = TestdataLavishSolution.generateSolution();
33133400
var entityGroup = new TestdataLavishEntityGroup("MyEntityGroup");
33143401
var valueGroup = new TestdataLavishValueGroup("MyValueGroup");
@@ -3387,4 +3474,181 @@ public void precompute_join_filter_map_entity_right() {
33873474
assertMatch(value1, entity3),
33883475
assertMatch(value2, entity3));
33893476
}
3477+
3478+
private <A, B> void assertPrecompute(TestdataLavishSolution solution,
3479+
List<Pair<A, B>> expectedValues,
3480+
Function<PrecomputeFactory, BiConstraintStream<A, B>> entityStreamSupplier) {
3481+
var scoreDirector =
3482+
buildScoreDirector(factory -> factory.precompute(entityStreamSupplier)
3483+
.ifExists(TestdataLavishEntity.class)
3484+
.penalize(SimpleScore.ONE)
3485+
.asConstraint(TEST_CONSTRAINT_NAME));
3486+
3487+
// From scratch
3488+
scoreDirector.setWorkingSolution(solution);
3489+
assertScore(scoreDirector);
3490+
3491+
for (var entity : solution.getEntityList()) {
3492+
scoreDirector.beforeVariableChanged(entity, "value");
3493+
entity.setValue(solution.getFirstValue());
3494+
scoreDirector.afterVariableChanged(entity, "value");
3495+
}
3496+
3497+
assertScore(scoreDirector, expectedValues.stream()
3498+
.map(pair -> new Object[] { pair.key(), pair.value() })
3499+
.map(AbstractConstraintStreamTest::assertMatch)
3500+
.toArray(AssertableMatch[]::new));
3501+
}
3502+
3503+
@Override
3504+
@TestTemplate
3505+
public void precompute_ifExists() {
3506+
var solution = TestdataLavishSolution.generateEmptySolution();
3507+
var entityWithoutGroup = new TestdataLavishEntity();
3508+
var entityWithGroup = new TestdataLavishEntity();
3509+
var entityGroup = new TestdataLavishEntityGroup();
3510+
entityWithGroup.setEntityGroup(entityGroup);
3511+
solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup));
3512+
solution.getEntityGroupList().add(entityGroup);
3513+
var value = new TestdataLavishValue();
3514+
solution.getValueList().add(value);
3515+
3516+
assertPrecompute(solution, List.of(new Pair<>(entityWithGroup, value)),
3517+
pf -> pf.forEachUnfiltered(TestdataLavishEntity.class)
3518+
.join(TestdataLavishValue.class)
3519+
.ifExists(TestdataLavishEntityGroup.class, Joiners.equal(
3520+
(a, b) -> a.getEntityGroup(), Function.identity())));
3521+
}
3522+
3523+
@Override
3524+
@TestTemplate
3525+
public void precompute_ifNotExists() {
3526+
var solution = TestdataLavishSolution.generateEmptySolution();
3527+
var entityWithoutGroup = new TestdataLavishEntity();
3528+
var entityWithGroup = new TestdataLavishEntity();
3529+
var entityGroup = new TestdataLavishEntityGroup();
3530+
entityWithGroup.setEntityGroup(entityGroup);
3531+
solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup));
3532+
solution.getEntityGroupList().add(entityGroup);
3533+
3534+
var value = new TestdataLavishValue();
3535+
solution.getValueList().add(value);
3536+
3537+
assertPrecompute(solution, List.of(new Pair<>(entityWithoutGroup, value)),
3538+
pf -> pf.forEachUnfiltered(TestdataLavishEntity.class)
3539+
.join(TestdataLavishValue.class)
3540+
.ifNotExists(TestdataLavishEntityGroup.class, Joiners.equal(
3541+
(a, b) -> a.getEntityGroup(), Function.identity())));
3542+
}
3543+
3544+
@Override
3545+
@TestTemplate
3546+
public void precompute_groupBy() {
3547+
var solution = TestdataLavishSolution.generateEmptySolution();
3548+
var entityWithoutGroup = new TestdataLavishEntity();
3549+
var entityWithGroup = new TestdataLavishEntity();
3550+
var entityGroup = new TestdataLavishEntityGroup();
3551+
entityWithGroup.setEntityGroup(entityGroup);
3552+
solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup));
3553+
solution.getEntityGroupList().add(entityGroup);
3554+
3555+
var value = new TestdataLavishValue();
3556+
solution.getValueList().add(value);
3557+
3558+
assertPrecompute(solution, List.of(new Pair<>(entityGroup, 1)),
3559+
pf -> pf.forEachUnfiltered(TestdataLavishEntity.class)
3560+
.filter(entity -> entity.getEntityGroup() != null)
3561+
.groupBy(TestdataLavishEntity::getEntityGroup, ConstraintCollectors.count()));
3562+
}
3563+
3564+
@Override
3565+
@TestTemplate
3566+
public void precompute_flattenLast() {
3567+
var solution = TestdataLavishSolution.generateEmptySolution();
3568+
var entityWithoutGroup = new TestdataLavishEntity();
3569+
var entityWithGroup = new TestdataLavishEntity();
3570+
var entityGroup = new TestdataLavishEntityGroup();
3571+
entityWithGroup.setEntityGroup(entityGroup);
3572+
solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup));
3573+
solution.getEntityGroupList().add(entityGroup);
3574+
var value = new TestdataLavishValue();
3575+
solution.getValueList().add(value);
3576+
3577+
assertPrecompute(solution, List.of(new Pair<>(entityWithoutGroup, value),
3578+
new Pair<>(entityWithGroup, value)),
3579+
pf -> pf.forEachUnfiltered(TestdataLavishEntity.class)
3580+
.groupBy(ConstraintCollectors.toList())
3581+
.flattenLast(entityList -> entityList)
3582+
.join(TestdataLavishValue.class));
3583+
}
3584+
3585+
@Override
3586+
@TestTemplate
3587+
public void precompute_map() {
3588+
var solution = TestdataLavishSolution.generateEmptySolution();
3589+
var entityWithoutGroup = new TestdataLavishEntity();
3590+
var entityWithGroup1 = new TestdataLavishEntity();
3591+
var entityWithGroup2 = new TestdataLavishEntity();
3592+
var entityGroup = new TestdataLavishEntityGroup();
3593+
entityWithGroup1.setEntityGroup(entityGroup);
3594+
entityWithGroup2.setEntityGroup(entityGroup);
3595+
solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2));
3596+
solution.getEntityGroupList().add(entityGroup);
3597+
var value = new TestdataLavishValue();
3598+
solution.getValueList().add(value);
3599+
3600+
assertPrecompute(solution, List.of(new Pair<>(entityGroup, value),
3601+
new Pair<>(entityGroup, value)),
3602+
pf -> pf.forEachUnfiltered(TestdataLavishEntity.class)
3603+
.join(TestdataLavishValue.class)
3604+
.filter((entity, joinedValue) -> entity.getEntityGroup() != null)
3605+
.map((entity, joinedValue) -> entity.getEntityGroup(),
3606+
(entity, joinedValue) -> joinedValue));
3607+
}
3608+
3609+
@Override
3610+
@TestTemplate
3611+
public void precompute_concat() {
3612+
var solution = TestdataLavishSolution.generateEmptySolution();
3613+
var entityWithoutGroup = new TestdataLavishEntity();
3614+
var entityWithGroup = new TestdataLavishEntity();
3615+
var entityGroup = new TestdataLavishEntityGroup();
3616+
entityWithGroup.setEntityGroup(entityGroup);
3617+
solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup));
3618+
solution.getEntityGroupList().add(entityGroup);
3619+
var value = new TestdataLavishValue();
3620+
solution.getValueList().add(value);
3621+
3622+
assertPrecompute(solution, List.of(new Pair<>(entityWithoutGroup, value), new Pair<>(entityWithGroup, value)),
3623+
pf -> pf.forEachUnfiltered(TestdataLavishEntity.class)
3624+
.join(TestdataLavishValue.class)
3625+
.filter((entity, joinedValue) -> entity.getEntityGroup() == null)
3626+
.concat(pf.forEachUnfiltered(TestdataLavishEntity.class)
3627+
.join(TestdataLavishValue.class)
3628+
.filter((entity, joinedValue) -> entity.getEntityGroup() != null)));
3629+
}
3630+
3631+
@Override
3632+
@TestTemplate
3633+
public void precompute_distinct() {
3634+
var solution = TestdataLavishSolution.generateEmptySolution();
3635+
var entityWithoutGroup = new TestdataLavishEntity();
3636+
var entityWithGroup1 = new TestdataLavishEntity();
3637+
var entityWithGroup2 = new TestdataLavishEntity();
3638+
var entityGroup = new TestdataLavishEntityGroup();
3639+
entityWithGroup1.setEntityGroup(entityGroup);
3640+
entityWithGroup2.setEntityGroup(entityGroup);
3641+
solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2));
3642+
solution.getEntityGroupList().add(entityGroup);
3643+
var value = new TestdataLavishValue();
3644+
solution.getValueList().add(value);
3645+
3646+
assertPrecompute(solution, List.of(new Pair<>(entityGroup, value)),
3647+
pf -> pf.forEachUnfiltered(TestdataLavishEntity.class)
3648+
.join(TestdataLavishValue.class)
3649+
.filter((entity, joinedValue) -> entity.getEntityGroup() != null)
3650+
.map((entity, joinedValue) -> entity.getEntityGroup(),
3651+
(entity, joinedValue) -> joinedValue)
3652+
.distinct());
3653+
}
33903654
}

0 commit comments

Comments
 (0)