|
34 | 34 | import ai.timefold.solver.core.api.score.stream.ConstraintCollectors; |
35 | 35 | import ai.timefold.solver.core.api.score.stream.ConstraintJustification; |
36 | 36 | import ai.timefold.solver.core.api.score.stream.DefaultConstraintJustification; |
| 37 | +import ai.timefold.solver.core.api.score.stream.Joiners; |
| 38 | +import ai.timefold.solver.core.api.score.stream.PrecomputeFactory; |
| 39 | +import ai.timefold.solver.core.api.score.stream.bi.BiConstraintStream; |
37 | 40 | import ai.timefold.solver.core.impl.score.director.InnerScoreDirector; |
38 | 41 | import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStreamTest; |
39 | 42 | import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamFunctionalTest; |
40 | 43 | import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamImplSupport; |
| 44 | +import ai.timefold.solver.core.impl.util.Pair; |
41 | 45 | import ai.timefold.solver.core.testdomain.TestdataEntity; |
42 | 46 | import ai.timefold.solver.core.testdomain.list.unassignedvar.TestdataAllowsUnassignedValuesListEntity; |
43 | 47 | import ai.timefold.solver.core.testdomain.list.unassignedvar.TestdataAllowsUnassignedValuesListSolution; |
@@ -3307,8 +3311,91 @@ public void joinerEqualsAndSameness() { |
3307 | 3311 | assertMatch(entity3, entity2)); |
3308 | 3312 | } |
3309 | 3313 |
|
| 3314 | + @Override |
| 3315 | + @TestTemplate |
| 3316 | + public void precompute_filter_0_changed() { |
| 3317 | + var solution = TestdataLavishSolution.generateSolution(); |
| 3318 | + var entityGroup = new TestdataLavishEntityGroup("MyEntityGroup"); |
| 3319 | + var valueGroup = new TestdataLavishValueGroup("MyValueGroup"); |
| 3320 | + solution.getEntityGroupList().add(entityGroup); |
| 3321 | + solution.getValueGroupList().add(valueGroup); |
| 3322 | + |
| 3323 | + var value1 = Mockito.spy(new TestdataLavishValue("MyValue 1", valueGroup)); |
| 3324 | + solution.getValueList().add(value1); |
| 3325 | + var value2 = Mockito.spy(new TestdataLavishValue("MyValue 2", valueGroup)); |
| 3326 | + solution.getValueList().add(value2); |
| 3327 | + var value3 = Mockito.spy(new TestdataLavishValue("MyValue 3", null)); |
| 3328 | + solution.getValueList().add(value3); |
| 3329 | + |
| 3330 | + var entity1 = Mockito.spy(new TestdataLavishEntity("MyEntity 1", entityGroup, value1)); |
| 3331 | + solution.getEntityList().add(entity1); |
| 3332 | + var entity2 = new TestdataLavishEntity("MyEntity 2", entityGroup, value1); |
| 3333 | + solution.getEntityList().add(entity2); |
| 3334 | + var entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), |
| 3335 | + value1); |
| 3336 | + solution.getEntityList().add(entity3); |
| 3337 | + |
| 3338 | + var scoreDirector = |
| 3339 | + buildScoreDirector(factory -> factory.precompute(data -> data.forEachUnfiltered(TestdataLavishEntity.class) |
| 3340 | + .join(TestdataLavishValue.class) |
| 3341 | + .filter((entity, value) -> entity.getEntityGroup() == entityGroup |
| 3342 | + && value.getValueGroup() == valueGroup)) |
| 3343 | + .filter((entity, value) -> entity.getValue() == value1) |
| 3344 | + .penalize(SimpleScore.ONE) |
| 3345 | + .asConstraint(TEST_CONSTRAINT_NAME)); |
| 3346 | + |
| 3347 | + // From scratch |
| 3348 | + Mockito.reset(entity1); |
| 3349 | + scoreDirector.setWorkingSolution(solution); |
| 3350 | + assertScore(scoreDirector, |
| 3351 | + assertMatch(entity1, value1), |
| 3352 | + assertMatch(entity1, value2), |
| 3353 | + assertMatch(entity2, value1), |
| 3354 | + assertMatch(entity2, value2)); |
| 3355 | + Mockito.verify(entity1, Mockito.atLeastOnce()).getEntityGroup(); |
| 3356 | + |
| 3357 | + // Incrementally update a variable |
| 3358 | + Mockito.reset(entity1); |
| 3359 | + scoreDirector.beforeVariableChanged(entity1, "value"); |
| 3360 | + entity1.setValue(solution.getFirstValue()); |
| 3361 | + scoreDirector.afterVariableChanged(entity1, "value"); |
| 3362 | + assertScore(scoreDirector, |
| 3363 | + assertMatch(entity2, value1), |
| 3364 | + assertMatch(entity2, value2)); |
| 3365 | + Mockito.verify(entity1, Mockito.never()).getEntityGroup(); |
| 3366 | + |
| 3367 | + // Incrementally update a fact |
| 3368 | + scoreDirector.beforeProblemPropertyChanged(entity3); |
| 3369 | + entity3.setEntityGroup(entityGroup); |
| 3370 | + scoreDirector.afterProblemPropertyChanged(entity3); |
| 3371 | + assertScore(scoreDirector, |
| 3372 | + assertMatch(entity2, value1), |
| 3373 | + assertMatch(entity2, value2), |
| 3374 | + assertMatch(entity3, value1), |
| 3375 | + assertMatch(entity3, value2)); |
| 3376 | + |
| 3377 | + // Remove entity |
| 3378 | + scoreDirector.beforeEntityRemoved(entity3); |
| 3379 | + solution.getEntityList().remove(entity3); |
| 3380 | + scoreDirector.afterEntityRemoved(entity3); |
| 3381 | + assertScore(scoreDirector, |
| 3382 | + assertMatch(entity2, value1), |
| 3383 | + assertMatch(entity2, value2)); |
| 3384 | + |
| 3385 | + // Add it back again, to make sure it was properly removed before |
| 3386 | + scoreDirector.beforeEntityAdded(entity3); |
| 3387 | + solution.getEntityList().add(entity3); |
| 3388 | + scoreDirector.afterEntityAdded(entity3); |
| 3389 | + assertScore(scoreDirector, |
| 3390 | + assertMatch(entity2, value1), |
| 3391 | + assertMatch(entity2, value2), |
| 3392 | + assertMatch(entity3, value1), |
| 3393 | + assertMatch(entity3, value2)); |
| 3394 | + } |
| 3395 | + |
| 3396 | + @Override |
3310 | 3397 | @TestTemplate |
3311 | | - public void precompute_join_filter_map_entity_right() { |
| 3398 | + public void precompute_filter_1_changed() { |
3312 | 3399 | var solution = TestdataLavishSolution.generateSolution(); |
3313 | 3400 | var entityGroup = new TestdataLavishEntityGroup("MyEntityGroup"); |
3314 | 3401 | var valueGroup = new TestdataLavishValueGroup("MyValueGroup"); |
@@ -3387,4 +3474,181 @@ public void precompute_join_filter_map_entity_right() { |
3387 | 3474 | assertMatch(value1, entity3), |
3388 | 3475 | assertMatch(value2, entity3)); |
3389 | 3476 | } |
| 3477 | + |
| 3478 | + private <A, B> void assertPrecompute(TestdataLavishSolution solution, |
| 3479 | + List<Pair<A, B>> expectedValues, |
| 3480 | + Function<PrecomputeFactory, BiConstraintStream<A, B>> entityStreamSupplier) { |
| 3481 | + var scoreDirector = |
| 3482 | + buildScoreDirector(factory -> factory.precompute(entityStreamSupplier) |
| 3483 | + .ifExists(TestdataLavishEntity.class) |
| 3484 | + .penalize(SimpleScore.ONE) |
| 3485 | + .asConstraint(TEST_CONSTRAINT_NAME)); |
| 3486 | + |
| 3487 | + // From scratch |
| 3488 | + scoreDirector.setWorkingSolution(solution); |
| 3489 | + assertScore(scoreDirector); |
| 3490 | + |
| 3491 | + for (var entity : solution.getEntityList()) { |
| 3492 | + scoreDirector.beforeVariableChanged(entity, "value"); |
| 3493 | + entity.setValue(solution.getFirstValue()); |
| 3494 | + scoreDirector.afterVariableChanged(entity, "value"); |
| 3495 | + } |
| 3496 | + |
| 3497 | + assertScore(scoreDirector, expectedValues.stream() |
| 3498 | + .map(pair -> new Object[] { pair.key(), pair.value() }) |
| 3499 | + .map(AbstractConstraintStreamTest::assertMatch) |
| 3500 | + .toArray(AssertableMatch[]::new)); |
| 3501 | + } |
| 3502 | + |
| 3503 | + @Override |
| 3504 | + @TestTemplate |
| 3505 | + public void precompute_ifExists() { |
| 3506 | + var solution = TestdataLavishSolution.generateEmptySolution(); |
| 3507 | + var entityWithoutGroup = new TestdataLavishEntity(); |
| 3508 | + var entityWithGroup = new TestdataLavishEntity(); |
| 3509 | + var entityGroup = new TestdataLavishEntityGroup(); |
| 3510 | + entityWithGroup.setEntityGroup(entityGroup); |
| 3511 | + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); |
| 3512 | + solution.getEntityGroupList().add(entityGroup); |
| 3513 | + var value = new TestdataLavishValue(); |
| 3514 | + solution.getValueList().add(value); |
| 3515 | + |
| 3516 | + assertPrecompute(solution, List.of(new Pair<>(entityWithGroup, value)), |
| 3517 | + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) |
| 3518 | + .join(TestdataLavishValue.class) |
| 3519 | + .ifExists(TestdataLavishEntityGroup.class, Joiners.equal( |
| 3520 | + (a, b) -> a.getEntityGroup(), Function.identity()))); |
| 3521 | + } |
| 3522 | + |
| 3523 | + @Override |
| 3524 | + @TestTemplate |
| 3525 | + public void precompute_ifNotExists() { |
| 3526 | + var solution = TestdataLavishSolution.generateEmptySolution(); |
| 3527 | + var entityWithoutGroup = new TestdataLavishEntity(); |
| 3528 | + var entityWithGroup = new TestdataLavishEntity(); |
| 3529 | + var entityGroup = new TestdataLavishEntityGroup(); |
| 3530 | + entityWithGroup.setEntityGroup(entityGroup); |
| 3531 | + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); |
| 3532 | + solution.getEntityGroupList().add(entityGroup); |
| 3533 | + |
| 3534 | + var value = new TestdataLavishValue(); |
| 3535 | + solution.getValueList().add(value); |
| 3536 | + |
| 3537 | + assertPrecompute(solution, List.of(new Pair<>(entityWithoutGroup, value)), |
| 3538 | + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) |
| 3539 | + .join(TestdataLavishValue.class) |
| 3540 | + .ifNotExists(TestdataLavishEntityGroup.class, Joiners.equal( |
| 3541 | + (a, b) -> a.getEntityGroup(), Function.identity()))); |
| 3542 | + } |
| 3543 | + |
| 3544 | + @Override |
| 3545 | + @TestTemplate |
| 3546 | + public void precompute_groupBy() { |
| 3547 | + var solution = TestdataLavishSolution.generateEmptySolution(); |
| 3548 | + var entityWithoutGroup = new TestdataLavishEntity(); |
| 3549 | + var entityWithGroup = new TestdataLavishEntity(); |
| 3550 | + var entityGroup = new TestdataLavishEntityGroup(); |
| 3551 | + entityWithGroup.setEntityGroup(entityGroup); |
| 3552 | + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); |
| 3553 | + solution.getEntityGroupList().add(entityGroup); |
| 3554 | + |
| 3555 | + var value = new TestdataLavishValue(); |
| 3556 | + solution.getValueList().add(value); |
| 3557 | + |
| 3558 | + assertPrecompute(solution, List.of(new Pair<>(entityGroup, 1)), |
| 3559 | + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) |
| 3560 | + .filter(entity -> entity.getEntityGroup() != null) |
| 3561 | + .groupBy(TestdataLavishEntity::getEntityGroup, ConstraintCollectors.count())); |
| 3562 | + } |
| 3563 | + |
| 3564 | + @Override |
| 3565 | + @TestTemplate |
| 3566 | + public void precompute_flattenLast() { |
| 3567 | + var solution = TestdataLavishSolution.generateEmptySolution(); |
| 3568 | + var entityWithoutGroup = new TestdataLavishEntity(); |
| 3569 | + var entityWithGroup = new TestdataLavishEntity(); |
| 3570 | + var entityGroup = new TestdataLavishEntityGroup(); |
| 3571 | + entityWithGroup.setEntityGroup(entityGroup); |
| 3572 | + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); |
| 3573 | + solution.getEntityGroupList().add(entityGroup); |
| 3574 | + var value = new TestdataLavishValue(); |
| 3575 | + solution.getValueList().add(value); |
| 3576 | + |
| 3577 | + assertPrecompute(solution, List.of(new Pair<>(entityWithoutGroup, value), |
| 3578 | + new Pair<>(entityWithGroup, value)), |
| 3579 | + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) |
| 3580 | + .groupBy(ConstraintCollectors.toList()) |
| 3581 | + .flattenLast(entityList -> entityList) |
| 3582 | + .join(TestdataLavishValue.class)); |
| 3583 | + } |
| 3584 | + |
| 3585 | + @Override |
| 3586 | + @TestTemplate |
| 3587 | + public void precompute_map() { |
| 3588 | + var solution = TestdataLavishSolution.generateEmptySolution(); |
| 3589 | + var entityWithoutGroup = new TestdataLavishEntity(); |
| 3590 | + var entityWithGroup1 = new TestdataLavishEntity(); |
| 3591 | + var entityWithGroup2 = new TestdataLavishEntity(); |
| 3592 | + var entityGroup = new TestdataLavishEntityGroup(); |
| 3593 | + entityWithGroup1.setEntityGroup(entityGroup); |
| 3594 | + entityWithGroup2.setEntityGroup(entityGroup); |
| 3595 | + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2)); |
| 3596 | + solution.getEntityGroupList().add(entityGroup); |
| 3597 | + var value = new TestdataLavishValue(); |
| 3598 | + solution.getValueList().add(value); |
| 3599 | + |
| 3600 | + assertPrecompute(solution, List.of(new Pair<>(entityGroup, value), |
| 3601 | + new Pair<>(entityGroup, value)), |
| 3602 | + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) |
| 3603 | + .join(TestdataLavishValue.class) |
| 3604 | + .filter((entity, joinedValue) -> entity.getEntityGroup() != null) |
| 3605 | + .map((entity, joinedValue) -> entity.getEntityGroup(), |
| 3606 | + (entity, joinedValue) -> joinedValue)); |
| 3607 | + } |
| 3608 | + |
| 3609 | + @Override |
| 3610 | + @TestTemplate |
| 3611 | + public void precompute_concat() { |
| 3612 | + var solution = TestdataLavishSolution.generateEmptySolution(); |
| 3613 | + var entityWithoutGroup = new TestdataLavishEntity(); |
| 3614 | + var entityWithGroup = new TestdataLavishEntity(); |
| 3615 | + var entityGroup = new TestdataLavishEntityGroup(); |
| 3616 | + entityWithGroup.setEntityGroup(entityGroup); |
| 3617 | + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); |
| 3618 | + solution.getEntityGroupList().add(entityGroup); |
| 3619 | + var value = new TestdataLavishValue(); |
| 3620 | + solution.getValueList().add(value); |
| 3621 | + |
| 3622 | + assertPrecompute(solution, List.of(new Pair<>(entityWithoutGroup, value), new Pair<>(entityWithGroup, value)), |
| 3623 | + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) |
| 3624 | + .join(TestdataLavishValue.class) |
| 3625 | + .filter((entity, joinedValue) -> entity.getEntityGroup() == null) |
| 3626 | + .concat(pf.forEachUnfiltered(TestdataLavishEntity.class) |
| 3627 | + .join(TestdataLavishValue.class) |
| 3628 | + .filter((entity, joinedValue) -> entity.getEntityGroup() != null))); |
| 3629 | + } |
| 3630 | + |
| 3631 | + @Override |
| 3632 | + @TestTemplate |
| 3633 | + public void precompute_distinct() { |
| 3634 | + var solution = TestdataLavishSolution.generateEmptySolution(); |
| 3635 | + var entityWithoutGroup = new TestdataLavishEntity(); |
| 3636 | + var entityWithGroup1 = new TestdataLavishEntity(); |
| 3637 | + var entityWithGroup2 = new TestdataLavishEntity(); |
| 3638 | + var entityGroup = new TestdataLavishEntityGroup(); |
| 3639 | + entityWithGroup1.setEntityGroup(entityGroup); |
| 3640 | + entityWithGroup2.setEntityGroup(entityGroup); |
| 3641 | + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2)); |
| 3642 | + solution.getEntityGroupList().add(entityGroup); |
| 3643 | + var value = new TestdataLavishValue(); |
| 3644 | + solution.getValueList().add(value); |
| 3645 | + |
| 3646 | + assertPrecompute(solution, List.of(new Pair<>(entityGroup, value)), |
| 3647 | + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) |
| 3648 | + .join(TestdataLavishValue.class) |
| 3649 | + .filter((entity, joinedValue) -> entity.getEntityGroup() != null) |
| 3650 | + .map((entity, joinedValue) -> entity.getEntityGroup(), |
| 3651 | + (entity, joinedValue) -> joinedValue) |
| 3652 | + .distinct()); |
| 3653 | + } |
3390 | 3654 | } |
0 commit comments