diff --git a/core/src/main/java/ai/timefold/solver/core/api/score/stream/ConstraintFactory.java b/core/src/main/java/ai/timefold/solver/core/api/score/stream/ConstraintFactory.java index 54e121dee2..7dd386052a 100644 --- a/core/src/main/java/ai/timefold/solver/core/api/score/stream/ConstraintFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/api/score/stream/ConstraintFactory.java @@ -15,19 +15,19 @@ import ai.timefold.solver.core.api.score.stream.bi.BiJoiner; import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream; -import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.NullMarked; /** * The factory to create every {@link ConstraintStream} (for example with {@link #forEach(Class)}) * which ends in a {@link Constraint} returned by {@link ConstraintProvider#defineConstraints(ConstraintFactory)}. */ +@NullMarked public interface ConstraintFactory { /** * @deprecated Do not rely on any constraint package in user code. */ @Deprecated(forRemoval = true, since = "1.13.0") - @NonNull String getDefaultConstraintPackage(); // ************************************************************************ @@ -65,7 +65,7 @@ public interface ConstraintFactory { * * @param the type of the matched problem fact or {@link PlanningEntity planning entity} */ - @NonNull UniConstraintStream forEach(@NonNull Class sourceClass); + UniConstraintStream forEach(Class sourceClass); /** * As defined by {@link #forEachIncludingUnassigned(Class)}. @@ -73,7 +73,7 @@ public interface ConstraintFactory { * @deprecated Use {@link #forEachIncludingUnassigned(Class)} instead. 
*/ @Deprecated(forRemoval = true, since = "1.8.0") - default @NonNull UniConstraintStream forEachIncludingNullVars(@NonNull Class sourceClass) { + default UniConstraintStream forEachIncludingNullVars(Class sourceClass) { return forEachIncludingUnassigned(sourceClass); } @@ -86,7 +86,7 @@ public interface ConstraintFactory { * * @param the type of the matched problem fact or {@link PlanningEntity planning entity} */ - @NonNull UniConstraintStream forEachIncludingUnassigned(@NonNull Class sourceClass); + UniConstraintStream forEachIncludingUnassigned(Class sourceClass); /** * As defined by {@link #forEach(Class)}, @@ -98,7 +98,7 @@ public interface ConstraintFactory { * * @param the type of the matched problem fact or {@link PlanningEntity planning entity} */ - @NonNull UniConstraintStream forEachUnfiltered(@NonNull Class sourceClass); + UniConstraintStream forEachUnfiltered(Class sourceClass); /** * Create a new {@link BiConstraintStream} for every unique combination of A and another A with a higher {@link PlanningId}. 
@@ -114,7 +114,7 @@ public interface ConstraintFactory { * @param the type of the matched problem fact or {@link PlanningEntity planning entity} * @return a stream that matches every unique combination of A and another A */ - default @NonNull BiConstraintStream forEachUniquePair(@NonNull Class sourceClass) { + default BiConstraintStream forEachUniquePair(Class sourceClass) { return forEachUniquePair(sourceClass, new BiJoiner[0]); } @@ -135,8 +135,8 @@ public interface ConstraintFactory { * @param the type of the matched problem fact or {@link PlanningEntity planning entity} * @return a stream that matches every unique combination of A and another A for which the {@link BiJoiner} is true */ - default @NonNull BiConstraintStream forEachUniquePair(@NonNull Class sourceClass, - @NonNull BiJoiner joiner) { + default BiConstraintStream forEachUniquePair(Class sourceClass, + BiJoiner joiner) { return forEachUniquePair(sourceClass, new BiJoiner[] { joiner }); } @@ -147,9 +147,9 @@ public interface ConstraintFactory { * @return a stream that matches every unique combination of A and another A for which all the * {@link BiJoiner joiners} are true */ - default @NonNull BiConstraintStream forEachUniquePair(@NonNull Class sourceClass, - @NonNull BiJoiner joiner1, - @NonNull BiJoiner joiner2) { + default BiConstraintStream forEachUniquePair(Class sourceClass, + BiJoiner joiner1, + BiJoiner joiner2) { return forEachUniquePair(sourceClass, new BiJoiner[] { joiner1, joiner2 }); } @@ -160,9 +160,9 @@ public interface ConstraintFactory { * @return a stream that matches every unique combination of A and another A for which all the * {@link BiJoiner joiners} are true */ - default @NonNull BiConstraintStream forEachUniquePair(@NonNull Class sourceClass, - @NonNull BiJoiner joiner1, @NonNull BiJoiner joiner2, - @NonNull BiJoiner joiner3) { + default BiConstraintStream forEachUniquePair(Class sourceClass, + BiJoiner joiner1, BiJoiner joiner2, + BiJoiner joiner3) { return 
forEachUniquePair(sourceClass, new BiJoiner[] { joiner1, joiner2, joiner3 }); } @@ -173,9 +173,9 @@ public interface ConstraintFactory { * @return a stream that matches every unique combination of A and another A for which all the * {@link BiJoiner joiners} are true */ - default @NonNull BiConstraintStream forEachUniquePair(@NonNull Class sourceClass, - @NonNull BiJoiner joiner1, @NonNull BiJoiner joiner2, - @NonNull BiJoiner joiner3, @NonNull BiJoiner joiner4) { + default BiConstraintStream forEachUniquePair(Class sourceClass, + BiJoiner joiner1, BiJoiner joiner2, + BiJoiner joiner3, BiJoiner joiner4) { return forEachUniquePair(sourceClass, new BiJoiner[] { joiner1, joiner2, joiner3, joiner4 }); } @@ -190,7 +190,59 @@ public interface ConstraintFactory { * @return a stream that matches every unique combination of A and another A for which all the * {@link BiJoiner joiners} are true */ - @NonNull BiConstraintStream forEachUniquePair(@NonNull Class sourceClass, @NonNull BiJoiner... joiners); + BiConstraintStream forEachUniquePair(Class sourceClass, BiJoiner... joiners); + + // ************************************************************************ + // staticData + //************************************************************************ + + /** + * Computes and caches the tuples that would be produced by the given stream. + *

+ * IMPORTANT: As this is cached, it is vital the stream does not reference any variables + * (genuine or otherwise), as a score corruption would occur. + *

+ * For example, if employee is a {@link PlanningVariable} on Shift (a {@link PlanningEntity}), + * and start/end are facts on Shift, the following Constraint would cause a score corruption: + * + *

+     * BiConstraintStream<Shift, Shift> overlappingShifts(PrecomputeFactory precomputeFactory) {
+     *     return precomputeFactory.forEachUnfiltered(Shift.class)
+     *             .join(Shift.class,
+     *                     Joiners.overlapping(Shift::getStart, Shift::getEnd),
+     *                     Joiners.equal(Shift::getEmployee))
+     *             .filter((left, right) -> left != right);
+     * }
+     *
+     * Constraint noOverlappingShifts(ConstraintFactory constraintFactory) {
+     *     return constraintFactory.precompute(this::overlappingShifts)
+     *             .penalize(HardSoftScore.ONE_HARD)
+     *             .asConstraint("Overlapping shifts");
+     * }
+     * 
+ *

+ * You can (and should) use variables after the precompute. So the example above + * can be rewritten correctly like this and would not cause score corruptions: + *

+ * + *

+     * BiConstraintStream<Shift, Shift> overlappingShifts(PrecomputeFactory precomputeFactory) {
+     *     return precomputeFactory.forEachUnfiltered(Shift.class)
+     *             .join(Shift.class,
+     *                     Joiners.overlapping(Shift::getStart, Shift::getEnd))
+     *             .filter((left, right) -> left != right);
+     * }
+     *
+     * Constraint noOverlappingShifts(ConstraintFactory constraintFactory) {
+     *     return constraintFactory.precompute(this::overlappingShifts)
+     *             .filter((left, right) -> left.getEmployee() != null && left.getEmployee().equals(right.getEmployee()))
+     *             .penalize(HardSoftScore.ONE_HARD)
+     *             .asConstraint("Overlapping shifts");
+     * }
+     * 
+ */ + Stream_ + precompute(Function precomputeSupplier); // ************************************************************************ // from* (deprecated) @@ -232,7 +284,7 @@ public interface ConstraintFactory { * which both allow and don't allow unassigned values. */ @Deprecated(forRemoval = true) -
@NonNull UniConstraintStream from(@NonNull Class fromClass); + UniConstraintStream from(Class fromClass); /** * This method is deprecated. @@ -250,8 +302,7 @@ public interface ConstraintFactory { * @deprecated Prefer {@link #forEachIncludingUnassigned(Class)}. */ @Deprecated(forRemoval = true) - @NonNull - UniConstraintStream fromUnfiltered(@NonNull Class fromClass); + UniConstraintStream fromUnfiltered(Class fromClass); /** * This method is deprecated. @@ -277,7 +328,7 @@ public interface ConstraintFactory { * @return a stream that matches every unique combination of A and another A */ @Deprecated(forRemoval = true) - default @NonNull BiConstraintStream fromUniquePair(@NonNull Class fromClass) { + default BiConstraintStream fromUniquePair(Class fromClass) { return fromUniquePair(fromClass, new BiJoiner[0]); } @@ -308,7 +359,7 @@ public interface ConstraintFactory { * @return a stream that matches every unique combination of A and another A for which the {@link BiJoiner} is true */ @Deprecated(forRemoval = true) - default @NonNull BiConstraintStream fromUniquePair(@NonNull Class fromClass, @NonNull BiJoiner joiner) { + default BiConstraintStream fromUniquePair(Class fromClass, BiJoiner joiner) { return fromUniquePair(fromClass, new BiJoiner[] { joiner }); } @@ -329,8 +380,8 @@ public interface ConstraintFactory { * {@link BiJoiner joiners} are true */ @Deprecated(forRemoval = true) - default @NonNull BiConstraintStream fromUniquePair(@NonNull Class fromClass, @NonNull BiJoiner joiner1, - @NonNull BiJoiner joiner2) { + default BiConstraintStream fromUniquePair(Class fromClass, BiJoiner joiner1, + BiJoiner joiner2) { return fromUniquePair(fromClass, new BiJoiner[] { joiner1, joiner2 }); } @@ -352,8 +403,8 @@ public interface ConstraintFactory { * {@link BiJoiner joiners} are true */ @Deprecated(forRemoval = true) - default @NonNull BiConstraintStream fromUniquePair(@NonNull Class fromClass, @NonNull BiJoiner joiner1, - @NonNull BiJoiner joiner2, @NonNull BiJoiner 
joiner3) { + default BiConstraintStream fromUniquePair(Class fromClass, BiJoiner joiner1, + BiJoiner joiner2, BiJoiner joiner3) { return fromUniquePair(fromClass, new BiJoiner[] { joiner1, joiner2, joiner3 }); } @@ -374,9 +425,9 @@ public interface ConstraintFactory { * which both allow and don't allow unassigned values. */ @Deprecated(forRemoval = true) - default @NonNull BiConstraintStream fromUniquePair(@NonNull Class fromClass, - @NonNull BiJoiner joiner1, @NonNull BiJoiner joiner2, - @NonNull BiJoiner joiner3, @NonNull BiJoiner joiner4) { + default BiConstraintStream fromUniquePair(Class fromClass, + BiJoiner joiner1, BiJoiner joiner2, + BiJoiner joiner3, BiJoiner joiner4) { return fromUniquePair(fromClass, new BiJoiner[] { joiner1, joiner2, joiner3, joiner4 }); } @@ -401,7 +452,6 @@ public interface ConstraintFactory { * which both allow and don't allow unassigned values. */ @Deprecated(forRemoval = true) - @NonNull - BiConstraintStream fromUniquePair(@NonNull Class fromClass, @NonNull BiJoiner... joiners); + BiConstraintStream fromUniquePair(Class fromClass, BiJoiner... joiners); } diff --git a/core/src/main/java/ai/timefold/solver/core/api/score/stream/PrecomputeFactory.java b/core/src/main/java/ai/timefold/solver/core/api/score/stream/PrecomputeFactory.java new file mode 100644 index 0000000000..4583391bbd --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/api/score/stream/PrecomputeFactory.java @@ -0,0 +1,44 @@ +package ai.timefold.solver.core.api.score.stream; + +import java.util.function.Function; + +import ai.timefold.solver.core.api.domain.entity.PlanningEntity; +import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream; + +/** + * Similar to a {@link ConstraintFactory}, except its methods (and the + * {@link ConstraintStream}s they return) do not apply any automatic + * filters (like those mentioned in {@link ConstraintFactory#forEach(Class)}). 
+ */ +public interface PrecomputeFactory { + /** + * As defined by {@link ConstraintFactory#forEachUnfiltered(Class)}, + * with the additional change that any joining stream will also be unfiltered. + *

+ * For example, + *

+ * + *

+     * precomputeFactory.forEachUnfiltered(Shift.class)
+     *         .join(Shift.class, Joiners.equal(Shift::getLocation));
+     * 
+ *

+ * Would roughly be equivalent to + *

+ * + *

+     * constraintFactory.forEachUnfiltered(Shift.class)
+     *         .join(constraintFactory.forEachUnfiltered(Shift.class),
+     *                 Joiners.equal(Shift::getLocation));
+     * 
+ *

+ * Important: no variables can be referenced in any operations performed + * by the returned {@link ConstraintStream}, otherwise a score corruption will + * occur. + * See the note in {@link ConstraintFactory#precompute(Function)} for + * more details. + * + * @param the type of the matched problem fact or {@link PlanningEntity planning entity} + */ + UniConstraintStream forEachUnfiltered(Class sourceClass); +} diff --git a/core/src/main/java/ai/timefold/solver/core/api/score/stream/bi/BiConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/api/score/stream/bi/BiConstraintStream.java index f72e84d96a..677bde791f 100644 --- a/core/src/main/java/ai/timefold/solver/core/api/score/stream/bi/BiConstraintStream.java +++ b/core/src/main/java/ai/timefold/solver/core/api/score/stream/bi/BiConstraintStream.java @@ -33,6 +33,7 @@ import ai.timefold.solver.core.api.score.stream.tri.TriConstraintStream; import ai.timefold.solver.core.api.score.stream.tri.TriJoiner; import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStream; import ai.timefold.solver.core.impl.util.ConstantLambdaUtils; import org.jspecify.annotations.NonNull; @@ -1582,9 +1583,21 @@ TriConstraintStream concat(@NonNull TriConstraintStream ot @NonNull Function paddingFunction) { var firstStream = this; var remapped = firstStream.map(ConstantLambdaUtils.biPickFirst()); - var secondStream = getConstraintFactory().forEach(otherClass) - .ifNotExists(remapped, Joiners.equal()); - return firstStream.concat(secondStream, paddingFunction); + + if (firstStream instanceof AbstractConstraintStream abstractConstraintStream) { + var secondStream = switch (abstractConstraintStream.getRetrievalSemantics()) { + case STANDARD, LEGACY -> getConstraintFactory().forEach(otherClass); + case PRECOMPUTE -> getConstraintFactory().forEachUnfiltered(otherClass); + }; + return firstStream.concat(secondStream.ifNotExists(remapped, 
Joiners.equal()), + paddingFunction); + } else { + throw new IllegalStateException(""" + Impossible state: the %s class (%s) does not extend %s. + %s are not expected to be implemented by the user. + """.formatted(ConstraintStream.class.getSimpleName(), this.getClass().getSimpleName(), + AbstractConstraintStream.class.getSimpleName(), ConstraintStream.class.getSimpleName())); + } } // ************************************************************************ diff --git a/core/src/main/java/ai/timefold/solver/core/api/score/stream/quad/QuadConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/api/score/stream/quad/QuadConstraintStream.java index 0a136666ba..7fa0434678 100644 --- a/core/src/main/java/ai/timefold/solver/core/api/score/stream/quad/QuadConstraintStream.java +++ b/core/src/main/java/ai/timefold/solver/core/api/score/stream/quad/QuadConstraintStream.java @@ -34,6 +34,7 @@ import ai.timefold.solver.core.api.score.stream.penta.PentaJoiner; import ai.timefold.solver.core.api.score.stream.tri.TriConstraintStream; import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStream; import ai.timefold.solver.core.impl.util.ConstantLambdaUtils; import org.jspecify.annotations.NonNull; @@ -1083,9 +1084,21 @@ QuadConstraintStream concat(@NonNull TriConstraintStream ot @NonNull Function paddingFunctionD) { var firstStream = this; var remapped = firstStream.map(ConstantLambdaUtils.quadPickFirst()); - var secondStream = getConstraintFactory().forEach(otherClass) - .ifNotExists(remapped, Joiners.equal()); - return firstStream.concat(secondStream, paddingFunctionB, paddingFunctionC, paddingFunctionD); + + if (firstStream instanceof AbstractConstraintStream abstractConstraintStream) { + var secondStream = switch (abstractConstraintStream.getRetrievalSemantics()) { + case STANDARD, LEGACY -> getConstraintFactory().forEach(otherClass); + case PRECOMPUTE -> 
getConstraintFactory().forEachUnfiltered(otherClass); + }; + return firstStream.concat(secondStream.ifNotExists(remapped, Joiners.equal()), + paddingFunctionB, paddingFunctionC, paddingFunctionD); + } else { + throw new IllegalStateException(""" + Impossible state: the %s class (%s) does not extend %s. + %s are not expected to be implemented by the user. + """.formatted(ConstraintStream.class.getSimpleName(), this.getClass().getSimpleName(), + AbstractConstraintStream.class.getSimpleName(), ConstraintStream.class.getSimpleName())); + } } // ************************************************************************ diff --git a/core/src/main/java/ai/timefold/solver/core/api/score/stream/tri/TriConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/api/score/stream/tri/TriConstraintStream.java index 3d9222da56..3a97cef99e 100644 --- a/core/src/main/java/ai/timefold/solver/core/api/score/stream/tri/TriConstraintStream.java +++ b/core/src/main/java/ai/timefold/solver/core/api/score/stream/tri/TriConstraintStream.java @@ -34,6 +34,7 @@ import ai.timefold.solver.core.api.score.stream.quad.QuadConstraintStream; import ai.timefold.solver.core.api.score.stream.quad.QuadJoiner; import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStream; import ai.timefold.solver.core.impl.util.ConstantLambdaUtils; import org.jspecify.annotations.NonNull; @@ -1294,9 +1295,21 @@ TriConstraintStream concat(@NonNull BiConstraintStream otherStrea @NonNull Function paddingFunctionB, @NonNull Function paddingFunctionC) { var firstStream = this; var remapped = firstStream.map(ConstantLambdaUtils.triPickFirst()); - var secondStream = getConstraintFactory().forEach(otherClass) - .ifNotExists(remapped, Joiners.equal()); - return firstStream.concat(secondStream, paddingFunctionB, paddingFunctionC); + + if (firstStream instanceof AbstractConstraintStream abstractConstraintStream) { + var secondStream 
= switch (abstractConstraintStream.getRetrievalSemantics()) { + case STANDARD, LEGACY -> getConstraintFactory().forEach(otherClass); + case PRECOMPUTE -> getConstraintFactory().forEachUnfiltered(otherClass); + }; + return firstStream.concat(secondStream.ifNotExists(remapped, Joiners.equal()), + paddingFunctionB, paddingFunctionC); + } else { + throw new IllegalStateException(""" + Impossible state: the %s class (%s) does not extend %s. + %s are not expected to be implemented by the user. + """.formatted(ConstraintStream.class.getSimpleName(), this.getClass().getSimpleName(), + AbstractConstraintStream.class.getSimpleName(), ConstraintStream.class.getSimpleName())); + } } // ************************************************************************ diff --git a/core/src/main/java/ai/timefold/solver/core/api/score/stream/uni/UniConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/api/score/stream/uni/UniConstraintStream.java index 9d31a52e14..5ffe2b56c7 100644 --- a/core/src/main/java/ai/timefold/solver/core/api/score/stream/uni/UniConstraintStream.java +++ b/core/src/main/java/ai/timefold/solver/core/api/score/stream/uni/UniConstraintStream.java @@ -33,6 +33,7 @@ import ai.timefold.solver.core.api.score.stream.bi.BiJoiner; import ai.timefold.solver.core.api.score.stream.quad.QuadConstraintStream; import ai.timefold.solver.core.api.score.stream.tri.TriConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStream; import org.jspecify.annotations.NonNull; @@ -1727,9 +1728,19 @@ default UniConstraintStream ifNotExistsOtherIncludingNullVars(Class otherC */ default @NonNull UniConstraintStream complement(@NonNull Class otherClass) { var firstStream = this; - var secondStream = getConstraintFactory().forEach(otherClass) - .ifNotExists(firstStream, Joiners.equal()); - return firstStream.concat(secondStream); + if (firstStream instanceof AbstractConstraintStream abstractConstraintStream) { + var secondStream = switch 
(abstractConstraintStream.getRetrievalSemantics()) { + case STANDARD, LEGACY -> getConstraintFactory().forEach(otherClass); + case PRECOMPUTE -> getConstraintFactory().forEachUnfiltered(otherClass); + }; + return firstStream.concat(secondStream.ifNotExists(firstStream, Joiners.equal())); + } else { + throw new IllegalStateException(""" + Impossible state: the %s class (%s) does not extend %s. + %s are not expected to be implemented by the user. + """.formatted(ConstraintStream.class.getSimpleName(), this.getClass().getSimpleName(), + AbstractConstraintStream.class.getSimpleName(), ConstraintStream.class.getSimpleName())); + } } // ************************************************************************ diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractSession.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractSession.java index 38b65835ae..f4f1320ff8 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractSession.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractSession.java @@ -3,15 +3,15 @@ import java.util.IdentityHashMap; import java.util.Map; -import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode; -import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode.LifecycleOperation; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode.LifecycleOperation; public abstract class AbstractSession { private final NodeNetwork nodeNetwork; - private final Map, AbstractForEachUniNode[]> insertEffectiveClassToNodeArrayMap; - private final Map, AbstractForEachUniNode[]> updateEffectiveClassToNodeArrayMap; - private final Map, AbstractForEachUniNode[]> retractEffectiveClassToNodeArrayMap; + private final Map, BavetRootNode[]> insertEffectiveClassToNodeArrayMap; + private final Map, BavetRootNode[]> updateEffectiveClassToNodeArrayMap; + private final Map, BavetRootNode[]> 
retractEffectiveClassToNodeArrayMap; protected AbstractSession(NodeNetwork nodeNetwork) { this.nodeNetwork = nodeNetwork; @@ -22,13 +22,13 @@ protected AbstractSession(NodeNetwork nodeNetwork) { public final void insert(Object fact) { var factClass = fact.getClass(); - for (var node : findNodes(factClass, LifecycleOperation.INSERT)) { + for (var node : findNodes(factClass, BavetRootNode.LifecycleOperation.INSERT)) { node.insert(fact); } } @SuppressWarnings("unchecked") - private AbstractForEachUniNode[] findNodes(Class factClass, LifecycleOperation lifecycleOperation) { + private BavetRootNode[] findNodes(Class factClass, LifecycleOperation lifecycleOperation) { var effectiveClassToNodeArrayMap = switch (lifecycleOperation) { case INSERT -> insertEffectiveClassToNodeArrayMap; case UPDATE -> updateEffectiveClassToNodeArrayMap; @@ -37,9 +37,9 @@ private AbstractForEachUniNode[] findNodes(Class factClass, Lifecycle // Map.computeIfAbsent() would have created lambdas on the hot path, this will not. 
var nodeArray = effectiveClassToNodeArrayMap.get(factClass); if (nodeArray == null) { - nodeArray = nodeNetwork.getForEachNodes(factClass) + nodeArray = nodeNetwork.getRootNodesAcceptingType(factClass) .filter(node -> node.supports(lifecycleOperation)) - .toArray(AbstractForEachUniNode[]::new); + .toArray(BavetRootNode[]::new); effectiveClassToNodeArrayMap.put(factClass, nodeArray); } return nodeArray; @@ -47,14 +47,14 @@ private AbstractForEachUniNode[] findNodes(Class factClass, Lifecycle public final void update(Object fact) { var factClass = fact.getClass(); - for (var node : findNodes(factClass, LifecycleOperation.UPDATE)) { + for (var node : findNodes(factClass, BavetRootNode.LifecycleOperation.UPDATE)) { node.update(fact); } } public final void retract(Object fact) { var factClass = fact.getClass(); - for (var node : findNodes(factClass, LifecycleOperation.RETRACT)) { + for (var node : findNodes(factClass, BavetRootNode.LifecycleOperation.RETRACT)) { node.retract(fact); } } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/NodeNetwork.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/NodeNetwork.java index 62077a3ed4..07d21536ab 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/NodeNetwork.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/NodeNetwork.java @@ -7,8 +7,8 @@ import java.util.stream.Stream; import ai.timefold.solver.core.api.domain.solution.PlanningSolution; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; import ai.timefold.solver.core.impl.bavet.common.Propagator; -import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode; /** * Represents Bavet's network of nodes, specific to a particular session. @@ -19,7 +19,7 @@ * @param layeredNodes nodes grouped first by their layer, then by their index within the layer; * propagation needs to happen in this order. 
*/ -public record NodeNetwork(Map, List>> declaredClassToNodeMap, +public record NodeNetwork(Map, List>> declaredClassToNodeMap, Propagator[][] layeredNodes) { public static final NodeNetwork EMPTY = new NodeNetwork(Map.of(), new Propagator[0][0]); @@ -32,14 +32,19 @@ public int layerCount() { return layeredNodes.length; } - public Stream> getForEachNodes(Class factClass) { + public Stream> getRootNodes() { + return declaredClassToNodeMap.entrySet() + .stream() + .flatMap(entry -> entry.getValue().stream()); + } + + public Stream> getRootNodesAcceptingType(Class factClass) { // The node needs to match the fact, or the node needs to be applicable to the entire solution. // The latter is for FromSolution nodes. return declaredClassToNodeMap.entrySet() .stream() - .filter(entry -> factClass == PlanningSolution.class || entry.getKey().isAssignableFrom(factClass)) - .map(Map.Entry::getValue) - .flatMap(List::stream); + .flatMap(entry -> entry.getValue().stream()) + .filter(tupleSourceRoot -> factClass == PlanningSolution.class || tupleSourceRoot.allowsInstancesOf(factClass)); } public void settle() { diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/bi/PrecomputeBiNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/bi/PrecomputeBiNode.java new file mode 100644 index 0000000000..2656b71d93 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/bi/PrecomputeBiNode.java @@ -0,0 +1,28 @@ +package ai.timefold.solver.core.impl.bavet.bi; + +import ai.timefold.solver.core.impl.bavet.NodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.AbstractPrecomputeNode; +import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +public final class PrecomputeBiNode extends AbstractPrecomputeNode> { + private final 
int outputStoreSize; + + public PrecomputeBiNode(NodeNetwork nodeNetwork, + RecordingTupleLifecycle> recordingTupleNode, + int outputStoreSize, + TupleLifecycle> nextNodesTupleLifecycle, + Class[] sourceClasses) { + super(nodeNetwork, recordingTupleNode, nextNodesTupleLifecycle, sourceClasses); + this.outputStoreSize = outputStoreSize; + } + + @Override + protected BiTuple remapTuple(BiTuple tuple) { + return new BiTuple<>(tuple.factA, tuple.factB, outputStoreSize); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNodeBuildHelper.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNodeBuildHelper.java index 4bb364e006..7b3f605910 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNodeBuildHelper.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNodeBuildHelper.java @@ -16,7 +16,6 @@ import ai.timefold.solver.core.impl.bavet.common.tuple.LeftTupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.RightTupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; -import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode; public abstract class AbstractNodeBuildHelper { @@ -47,7 +46,7 @@ public void addNode(AbstractNode node, Stream_ creator) { public void addNode(AbstractNode node, Stream_ creator, Stream_ parent) { reversedNodeList.add(node); nodeCreatorMap.put(node, creator); - if (!(node instanceof AbstractForEachUniNode)) { + if (!(node instanceof BavetRootNode)) { if (parent == null) { throw new IllegalStateException("Impossible state: The node (%s) has no parent (%s)." 
.formatted(node, parent)); @@ -148,7 +147,7 @@ public AbstractNode findParentNode(Stream_ childNodeCreator) { } public static NodeNetwork buildNodeNetwork(List nodeList, - Map, List>> declaredClassToNodeMap) { + Map, List>> declaredClassToNodeMap) { var layerMap = new TreeMap>(); for (var node : nodeList) { layerMap.computeIfAbsent(node.getLayerIndex(), k -> new ArrayList<>()) @@ -206,7 +205,7 @@ public > List long determineLayerIndex(AbstractNode node, AbstractNodeBuildHelper buildHelper) { - if (node instanceof AbstractForEachUniNode) { // ForEach nodes, and only they, are in layer 0. + if (node instanceof BavetRootNode) { // Root nodes, and only they, are in layer 0. return 0; } else if (node instanceof AbstractTwoInputNode joinNode) { var nodeCreator = (BavetStreamBinaryOperation) buildHelper.getNodeCreatingStream(joinNode); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractPrecomputeNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractPrecomputeNode.java new file mode 100644 index 0000000000..c38d548fb2 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractPrecomputeNode.java @@ -0,0 +1,78 @@ +package ai.timefold.solver.core.impl.bavet.common; + +import ai.timefold.solver.core.impl.bavet.NodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.tuple.AbstractTuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +@NullMarked +public abstract class AbstractPrecomputeNode extends AbstractNode + implements BavetRootNode { + private final RecordAndReplayPropagator recordAndReplayPropagator; + private final Class[] sourceClasses; + + protected AbstractPrecomputeNode(NodeNetwork innerNodeNetwork, + RecordingTupleLifecycle recordingTupleLifecycle, + TupleLifecycle 
nextNodesTupleLifecycle, + Class[] sourceClasses) { + this.recordAndReplayPropagator = new RecordAndReplayPropagator<>(innerNodeNetwork, + recordingTupleLifecycle, + this::remapTuple, + nextNodesTupleLifecycle); + this.sourceClasses = sourceClasses; + } + + @Override + public final Propagator getPropagator() { + return recordAndReplayPropagator; + } + + @Override + public final boolean allowsInstancesOf(Class clazz) { + for (var sourceClass : sourceClasses) { + if (sourceClass.isAssignableFrom(clazz)) { + return true; + } + } + return false; + } + + @Override + public final Class[] getSourceClasses() { + return sourceClasses; + } + + @Override + public final boolean supports(BavetRootNode.LifecycleOperation lifecycleOperation) { + return true; + } + + @Override + public final void insert(@Nullable Object a) { + if (a == null) { + return; + } + recordAndReplayPropagator.insert(a); + } + + @Override + public final void update(@Nullable Object a) { + if (a == null) { + return; + } + recordAndReplayPropagator.update(a); + } + + @Override + public final void retract(@Nullable Object a) { + if (a == null) { + return; + } + recordAndReplayPropagator.retract(a); + } + + protected abstract Tuple_ remapTuple(Tuple_ tuple); +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/BavetRootNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/BavetRootNode.java new file mode 100644 index 0000000000..13aad93bf7 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/BavetRootNode.java @@ -0,0 +1,52 @@ +package ai.timefold.solver.core.impl.bavet.common; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +@NullMarked +public interface BavetRootNode { + void insert(@Nullable A a); + + void update(@Nullable A a); + + void retract(@Nullable A a); + + boolean allowsInstancesOf(Class clazz); + + Class[] getSourceClasses(); + + /** + * Determines if this node supports the given 
lifecycle operation. + * Unsupported nodes will not be called during that lifecycle operation. + * + * @param lifecycleOperation the lifecycle operation to check + * @return {@code true} if the given lifecycle operation is supported; otherwise, {@code false}. + */ + boolean supports(BavetRootNode.LifecycleOperation lifecycleOperation); + + /** + * Represents the various lifecycle operations that can be performed + * on tuples within a node in Bavet. + */ + enum LifecycleOperation { + /** + * Represents the operation of inserting a new tuple into the node. + * This operation is typically performed when a new fact is added to the working solution + * and needs to be propagated through the node network. + */ + INSERT, + /** + * Represents the operation of updating an existing tuple within the node. + * This operation is typically triggered when a fact in the working solution + * is modified, requiring the corresponding tuple to be updated and its changes + * propagated through the node network. + */ + UPDATE, + /** + * Represents the operation of retracting or removing an existing tuple from the node. + * This operation is typically used when a fact is removed from the working solution + * and its corresponding tuple needs to be removed from the node network. + */ + RETRACT + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/Propagator.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/Propagator.java index b3d5946541..2aa9958728 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/Propagator.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/Propagator.java @@ -38,7 +38,7 @@ * * @see PropagationQueue More information about propagation. */ -public sealed interface Propagator permits PropagationQueue { +public sealed interface Propagator permits PropagationQueue, RecordAndReplayPropagator { /** * Starts the propagation event. Must be followed by {@link #propagateUpdates()}. 
diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/RecordAndReplayPropagator.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/RecordAndReplayPropagator.java new file mode 100644 index 0000000000..74367fa2b5 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/RecordAndReplayPropagator.java @@ -0,0 +1,181 @@ +package ai.timefold.solver.core.impl.bavet.common; + +import java.util.ArrayList; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.UnaryOperator; + +import ai.timefold.solver.core.impl.bavet.NodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.tuple.AbstractTuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleState; +import ai.timefold.solver.core.impl.util.CollectionUtils; + +import org.jspecify.annotations.NullMarked; + +/** + * The implementation records the tuples each object affects inside + * an internal {@link ai.timefold.solver.core.impl.bavet.NodeNetwork} and + * replays them on update. + * Used by {@link AbstractPrecomputeNode} to precompute constraint streams. 
+ * + * @param + */ +@NullMarked +public final class RecordAndReplayPropagator + implements Propagator { + + private final Set retractQueue; + private final Set updateQueue; + private final Set insertQueue; + + private final NodeNetwork internalNodeNetwork; + private final RecordingTupleLifecycle recordingTupleLifecycle; + private final UnaryOperator internalTupleToOutputTupleMapper; + private final Map internalTupleToOutputTupleMap; + private final Map> objectToOutputTuplesMap; + + private final StaticPropagationQueue propagationQueue; + + public RecordAndReplayPropagator( + NodeNetwork internalNodeNetwork, + RecordingTupleLifecycle recordingTupleLifecycle, + UnaryOperator internalTupleToOutputTupleMapper, + TupleLifecycle nextNodesTupleLifecycle, int size) { + this.internalNodeNetwork = internalNodeNetwork; + this.recordingTupleLifecycle = recordingTupleLifecycle; + this.internalTupleToOutputTupleMapper = internalTupleToOutputTupleMapper; + this.internalTupleToOutputTupleMap = CollectionUtils.newIdentityHashMap(size); + this.objectToOutputTuplesMap = CollectionUtils.newIdentityHashMap(size); + + // Guesstimate that updates are dominant. 
+ this.retractQueue = CollectionUtils.newIdentityHashSet(size / 20); + this.updateQueue = CollectionUtils.newIdentityHashSet((size / 20) * 18); + this.insertQueue = CollectionUtils.newIdentityHashSet(size / 20); + + this.propagationQueue = new StaticPropagationQueue<>(nextNodesTupleLifecycle); + } + + public RecordAndReplayPropagator( + NodeNetwork internalNodeNetwork, + RecordingTupleLifecycle recordingTupleLifecycle, + UnaryOperator internalTupleToOutputTupleMapper, + TupleLifecycle nextNodesTupleLifecycle) { + this(internalNodeNetwork, recordingTupleLifecycle, internalTupleToOutputTupleMapper, nextNodesTupleLifecycle, 1000); + } + + public void insert(Object object) { + // do not remove a retract of the same fact (a fact was updated) + insertQueue.add(object); + } + + public void update(Object object) { + updateQueue.add(object); + } + + public void retract(Object object) { + // remove an insert then retract (a fact was inserted but retracted before settling) + // do not remove a retract then insert (a fact was updated) + if (!insertQueue.remove(object)) { + retractQueue.add(object); + } + } + + @Override + public void propagateRetracts() { + if (!retractQueue.isEmpty() || !insertQueue.isEmpty()) { + updateQueue.removeAll(retractQueue); + updateQueue.removeAll(insertQueue); + // Do not remove queued retracts from inserts; if a fact property + // change, there will be both a retract and insert for that fact + invalidateCache(); + + retractQueue.forEach(this::retractFromInternalNodeNetwork); + insertQueue.forEach(this::insertIntoInternalNodeNetwork); + retractQueue.clear(); + insertQueue.clear(); + + // settle the inner node network, so the inserts/retracts do not interfere + // with the recording of the first object's tuples + internalNodeNetwork.settle(); + recalculateTuples(); + propagationQueue.propagateRetracts(); + } + } + + @Override + public void propagateUpdates() { + for (var update : updateQueue) { + for (var updatedTuple : 
objectToOutputTuplesMap.get(update)) { + propagationQueue.update(updatedTuple); + } + } + updateQueue.clear(); + propagationQueue.propagateUpdates(); + } + + @Override + public void propagateInserts() { + // propagateRetracts clears/process the insertQueue + propagationQueue.propagateInserts(); + } + + private void insertIfAbsent(Tuple_ tuple) { + var state = tuple.state; + if (state != TupleState.CREATING) { + propagationQueue.insert(tuple); + } + } + + private void retractIfPresent(Tuple_ tuple) { + var state = tuple.state; + if (state.isDirty()) { + if (state == TupleState.DYING || state == TupleState.ABORTING) { + // We already retracted this tuple from another list, so we + // don't need to do anything + return; + } + propagationQueue.retract(tuple, state == TupleState.CREATING ? TupleState.ABORTING : TupleState.DYING); + } else { + propagationQueue.retract(tuple, TupleState.DYING); + } + } + + private void insertIntoInternalNodeNetwork(Object toInsert) { + objectToOutputTuplesMap.put(toInsert, new ArrayList<>()); + internalNodeNetwork.getRootNodesAcceptingType(toInsert.getClass()) + .forEach(node -> ((BavetRootNode) node).insert(toInsert)); + } + + private void retractFromInternalNodeNetwork(Object toRetract) { + objectToOutputTuplesMap.remove(toRetract); + internalNodeNetwork.getRootNodesAcceptingType(toRetract.getClass()) + .forEach(node -> ((BavetRootNode) node).retract(toRetract)); + } + + private void invalidateCache() { + objectToOutputTuplesMap.values().stream().flatMap(List::stream).forEach(this::retractIfPresent); + internalTupleToOutputTupleMap.clear(); + } + + private void recalculateTuples() { + for (var mappedTupleEntry : objectToOutputTuplesMap.entrySet()) { + mappedTupleEntry.getValue().clear(); + var invalidated = mappedTupleEntry.getKey(); + try (var unusedActiveRecordingLifecycle = recordingTupleLifecycle.recordInto( + new TupleRecorder<>(mappedTupleEntry.getValue(), internalTupleToOutputTupleMapper, + (IdentityHashMap) 
internalTupleToOutputTupleMap))) { + // Do a fake update on the object and settle the network; this will update precisely the + // tuples mapped to this node, which will then be recorded + internalNodeNetwork.getRootNodesAcceptingType(invalidated.getClass()) + .forEach(node -> ((BavetRootNode) node).update(invalidated)); + internalNodeNetwork.settle(); + } + } + objectToOutputTuplesMap.values().stream().flatMap(List::stream).forEach(this::insertIfAbsent); + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/TupleRecorder.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/TupleRecorder.java new file mode 100644 index 0000000000..be7c5694fe --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/TupleRecorder.java @@ -0,0 +1,18 @@ +package ai.timefold.solver.core.impl.bavet.common; + +import java.util.IdentityHashMap; +import java.util.List; +import java.util.function.UnaryOperator; + +import ai.timefold.solver.core.impl.bavet.common.tuple.AbstractTuple; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +public record TupleRecorder(List recordedTupleList, + UnaryOperator mapper, + IdentityHashMap inputTupleToOutputTuple) { + public void recordTuple(Tuple_ tuple) { + recordedTupleList.add(inputTupleToOutputTuple.computeIfAbsent(tuple, mapper)); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RecordingTupleLifecycle.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RecordingTupleLifecycle.java new file mode 100644 index 0000000000..fbfb04ed3f --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RecordingTupleLifecycle.java @@ -0,0 +1,42 @@ +package ai.timefold.solver.core.impl.bavet.common.tuple; + +import ai.timefold.solver.core.impl.bavet.common.TupleRecorder; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +@NullMarked +public class 
RecordingTupleLifecycle implements TupleLifecycle, AutoCloseable { + @Nullable + TupleRecorder tupleRecorder; + + public RecordingTupleLifecycle recordInto(TupleRecorder tupleRecorder) { + this.tupleRecorder = tupleRecorder; + return this; + } + + @Override + public void close() { + this.tupleRecorder = null; + } + + @Override + public void insert(Tuple_ tuple) { + if (tupleRecorder != null) { + throw new IllegalStateException("Impossible state: tuple %s was inserted during recording".formatted(tuple)); + } + } + + @Override + public void update(Tuple_ tuple) { + if (tupleRecorder != null) { + tupleRecorder.recordTuple(tuple); + } + } + + @Override + public void retract(Tuple_ tuple) { + // Not illegal; a filter can retract a never inserted tuple on update, + // since it does not remember what tuples it accepted to save memory + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/TupleLifecycle.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/TupleLifecycle.java index a17752a3db..675e6f7164 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/TupleLifecycle.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/TupleLifecycle.java @@ -51,6 +51,10 @@ static TupleLifecycle> conditionally(TupleLifecycle< tuple -> predicate.test(tuple.factA, tuple.factB, tuple.factC, tuple.factD)); } + static TupleLifecycle recording() { + return new RecordingTupleLifecycle<>(); + } + void insert(Tuple_ tuple); void update(Tuple_ tuple); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/quad/PrecomputeQuadNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/quad/PrecomputeQuadNode.java new file mode 100644 index 0000000000..e107d17596 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/quad/PrecomputeQuadNode.java @@ -0,0 +1,29 @@ +package ai.timefold.solver.core.impl.bavet.quad; + +import 
ai.timefold.solver.core.impl.bavet.NodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.AbstractPrecomputeNode; +import ai.timefold.solver.core.impl.bavet.common.tuple.QuadTuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +public final class PrecomputeQuadNode extends AbstractPrecomputeNode> { + private final int outputStoreSize; + + public PrecomputeQuadNode(NodeNetwork nodeNetwork, + RecordingTupleLifecycle> recordingTupleNode, + int outputStoreSize, + TupleLifecycle> nextNodesTupleLifecycle, + Class[] sourceClasses) { + super(nodeNetwork, recordingTupleNode, nextNodesTupleLifecycle, sourceClasses); + this.outputStoreSize = outputStoreSize; + } + + @Override + protected QuadTuple remapTuple(QuadTuple tuple) { + return new QuadTuple<>(tuple.factA, tuple.factB, tuple.factC, tuple.factD, + outputStoreSize); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/tri/PrecomputeTriNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/tri/PrecomputeTriNode.java new file mode 100644 index 0000000000..2e06304781 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/tri/PrecomputeTriNode.java @@ -0,0 +1,28 @@ +package ai.timefold.solver.core.impl.bavet.tri; + +import ai.timefold.solver.core.impl.bavet.NodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.AbstractPrecomputeNode; +import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.TriTuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +public final class PrecomputeTriNode extends AbstractPrecomputeNode> { + private final int outputStoreSize; + + public PrecomputeTriNode(NodeNetwork nodeNetwork, + 
RecordingTupleLifecycle> recordingTupleNode, + int outputStoreSize, + TupleLifecycle> nextNodesTupleLifecycle, + Class[] sourceClasses) { + super(nodeNetwork, recordingTupleNode, nextNodesTupleLifecycle, sourceClasses); + this.outputStoreSize = outputStoreSize; + } + + @Override + protected TriTuple remapTuple(TriTuple tuple) { + return new TriTuple<>(tuple.factA, tuple.factB, tuple.factC, outputStoreSize); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/AbstractForEachUniNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/AbstractForEachUniNode.java index 510b3586cc..ed857aa335 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/AbstractForEachUniNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/AbstractForEachUniNode.java @@ -4,6 +4,7 @@ import java.util.Map; import ai.timefold.solver.core.impl.bavet.common.AbstractNode; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; import ai.timefold.solver.core.impl.bavet.common.Propagator; import ai.timefold.solver.core.impl.bavet.common.StaticPropagationQueue; import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; @@ -24,6 +25,7 @@ @NullMarked public abstract sealed class AbstractForEachUniNode extends AbstractNode + implements BavetRootNode permits ForEachFilteredUniNode, ForEachUnfilteredUniNode { private final Class forEachClass; @@ -38,6 +40,17 @@ protected AbstractForEachUniNode(Class forEachClass, TupleLifecycle(nextNodesTupleLifecycle); } + @Override + public boolean allowsInstancesOf(Class clazz) { + return forEachClass.isAssignableFrom(clazz); + } + + @Override + public Class[] getSourceClasses() { + return new Class[] { forEachClass }; + } + + @Override public void insert(@Nullable A a) { var tuple = new UniTuple<>(a, outputStoreSize); var old = tupleMap.put(a, tuple); @@ -48,8 +61,6 @@ public void insert(@Nullable A a) { propagationQueue.insert(tuple); } - public abstract void 
update(@Nullable A a); - protected final void updateExisting(@Nullable A a, UniTuple tuple) { var state = tuple.state; if (state.isDirty()) { @@ -63,6 +74,7 @@ protected final void updateExisting(@Nullable A a, UniTuple tuple) { } } + @Override public void retract(@Nullable A a) { var tuple = tupleMap.remove(a); if (tuple == null) { @@ -94,45 +106,10 @@ public final Class getForEachClass() { return forEachClass; } - /** - * Determines if this node supports the given lifecycle operation. - * Unsupported nodes will not be called during that lifecycle operation. - * - * @param lifecycleOperation the lifecycle operation to check - * @return {@code true} if the given lifecycle operation is supported; otherwise, {@code false}. - */ - public abstract boolean supports(LifecycleOperation lifecycleOperation); - @Override public final String toString() { return "%s(%s)" .formatted(getClass().getSimpleName(), forEachClass.getSimpleName()); } - /** - * Represents the various lifecycle operations that can be performed - * on tuples within a node in Bavet. - */ - public enum LifecycleOperation { - /** - * Represents the operation of inserting a new tuple into the node. - * This operation is typically performed when a new fact is added to the working solution - * and needs to be propagated through the node network. - */ - INSERT, - /** - * Represents the operation of updating an existing tuple within the node. - * This operation is typically triggered when a fact in the working solution - * is modified, requiring the corresponding tuple to be updated and its changes - * propagated through the node network. - */ - UPDATE, - /** - * Represents the operation of retracting or removing an existing tuple from the node. - * This operation is typically used when a fact is removed from the working solution - * and its corresponding tuple needs to be removed from the node network. 
- */ - RETRACT - } - } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachFilteredUniNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachFilteredUniNode.java index 9514b2cf39..9f7b7fdc3c 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachFilteredUniNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachFilteredUniNode.java @@ -3,6 +3,7 @@ import java.util.Objects; import java.util.function.Predicate; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; @@ -51,7 +52,7 @@ public void retract(@Nullable A a) { } @Override - public boolean supports(LifecycleOperation lifecycleOperation) { + public boolean supports(BavetRootNode.LifecycleOperation lifecycleOperation) { return true; } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachUnfilteredUniNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachUnfilteredUniNode.java index 3660cad1e9..f08516bd1c 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachUnfilteredUniNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachUnfilteredUniNode.java @@ -1,5 +1,6 @@ package ai.timefold.solver.core.impl.bavet.uni; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; @@ -26,7 +27,7 @@ public void update(@Nullable A a) { } @Override - public boolean supports(LifecycleOperation lifecycleOperation) { + public boolean supports(BavetRootNode.LifecycleOperation lifecycleOperation) { return true; } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/PrecomputeUniNode.java 
b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/PrecomputeUniNode.java new file mode 100644 index 0000000000..8237b8407b --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/PrecomputeUniNode.java @@ -0,0 +1,28 @@ +package ai.timefold.solver.core.impl.bavet.uni; + +import ai.timefold.solver.core.impl.bavet.NodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.AbstractPrecomputeNode; +import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +public final class PrecomputeUniNode extends AbstractPrecomputeNode> { + private final int outputStoreSize; + + public PrecomputeUniNode(NodeNetwork nodeNetwork, + RecordingTupleLifecycle> recordingTupleNode, + int outputStoreSize, + TupleLifecycle> nextNodesTupleLifecycle, + Class[] sourceClasses) { + super(nodeNetwork, recordingTupleNode, nextNodesTupleLifecycle, sourceClasses); + this.outputStoreSize = outputStoreSize; + } + + @Override + protected UniTuple remapTuple(UniTuple tuple) { + return new UniTuple<>(tuple.factA, outputStoreSize); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSessionFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSessionFactory.java index 6d0a4f103d..879f50cffc 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSessionFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSessionFactory.java @@ -9,6 +9,7 @@ import ai.timefold.solver.core.impl.bavet.NodeNetwork; import ai.timefold.solver.core.impl.bavet.common.AbstractNodeBuildHelper; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; import 
ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.AbstractEnumeratingStream; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.DataNodeBuildHelper; @@ -38,7 +39,7 @@ public DatasetSession buildSession(SessionContext context) private NodeNetwork buildNodeNetwork(Set> enumeratingStreamSet, DataNodeBuildHelper buildHelper, Consumer nodeNetworkVisualizationConsumer) { - var declaredClassToNodeMap = new LinkedHashMap, List>>(); + var declaredClassToNodeMap = new LinkedHashMap, List>>(); var nodeList = buildHelper.buildNodeList(enumeratingStreamSet, buildHelper, AbstractEnumeratingStream::buildNode, node -> { if (!(node instanceof AbstractForEachUniNode forEachUniNode)) { diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractForEachEnumeratingStream.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractForEachEnumeratingStream.java index 3b64467bca..a3f46ef2d6 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractForEachEnumeratingStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractForEachEnumeratingStream.java @@ -1,6 +1,6 @@ package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.uni; -import static ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode.LifecycleOperation; +import static ai.timefold.solver.core.impl.bavet.common.BavetRootNode.LifecycleOperation; import java.util.Objects; import java.util.Set; diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/director/stream/BavetConstraintStreamScoreDirector.java b/core/src/main/java/ai/timefold/solver/core/impl/score/director/stream/BavetConstraintStreamScoreDirector.java index 3cd43f8132..9e16407b75 100644 --- 
a/core/src/main/java/ai/timefold/solver/core/impl/score/director/stream/BavetConstraintStreamScoreDirector.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/director/stream/BavetConstraintStreamScoreDirector.java @@ -180,11 +180,17 @@ public void afterProblemFactAdded(Object problemFact) { super.afterProblemFactAdded(problemFact); } - // public void beforeProblemPropertyChanged(Object problemFactOrEntity) // Do nothing + @Override + public void beforeProblemPropertyChanged(Object problemFactOrEntity) { + // Since this is called when a fact (not a variable) changes, + // we need to retract and reinsert to update cached static data + super.beforeProblemPropertyChanged(problemFactOrEntity); + session.retract(problemFactOrEntity); + } @Override public void afterProblemPropertyChanged(Object problemFactOrEntity) { - session.update(problemFactOrEntity); + session.insert(problemFactOrEntity); super.afterProblemPropertyChanged(problemFactOrEntity); } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintFactory.java index 11720db829..b0704c3fca 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintFactory.java @@ -7,7 +7,9 @@ import java.util.function.Function; import java.util.function.Predicate; +import ai.timefold.solver.core.api.score.stream.ConstraintStream; import ai.timefold.solver.core.api.score.stream.Joiners; +import ai.timefold.solver.core.api.score.stream.PrecomputeFactory; import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream; import ai.timefold.solver.core.config.solver.EnvironmentMode; import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; @@ -15,14 +17,28 @@ import 
ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor; import ai.timefold.solver.core.impl.domain.variable.declarative.ConsistencyTracker; import ai.timefold.solver.core.impl.score.constraint.ConstraintMatchPolicy; +import ai.timefold.solver.core.impl.score.stream.bavet.bi.BavetAbstractBiConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.bi.BavetPrecomputeBiConstraintStream; import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeBiConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeQuadConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeTriConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeUniConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.quad.BavetAbstractQuadConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.quad.BavetPrecomputeQuadConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.tri.BavetAbstractTriConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.tri.BavetPrecomputeTriConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.uni.BavetAbstractUniConstraintStream; import ai.timefold.solver.core.impl.score.stream.bavet.uni.BavetForEachUniConstraintStream; +import ai.timefold.solver.core.impl.score.stream.bavet.uni.BavetPrecomputeUniConstraintStream; import ai.timefold.solver.core.impl.score.stream.common.ForEachFilteringCriteria; import ai.timefold.solver.core.impl.score.stream.common.InnerConstraintFactory; import ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics; -import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; +@NullMarked public final class 
BavetConstraintFactory extends InnerConstraintFactory> { @@ -56,12 +72,12 @@ public BavetConstraintFactory(SolutionDescriptor solutionDescriptor, } } - private static String determineDefaultConstraintPackage(Package pkg) { + private static String determineDefaultConstraintPackage(@Nullable Package pkg) { var asString = pkg == null ? "" : pkg.getName(); return determineDefaultConstraintPackage(asString); } - private static String determineDefaultConstraintPackage(String constraintPackage) { + private static String determineDefaultConstraintPackage(@Nullable String constraintPackage) { return constraintPackage == null || constraintPackage.isEmpty() ? DEFAULT_CONSTRAINT_PACKAGE : constraintPackage; } @@ -101,25 +117,30 @@ public > Stream_ share( // Required for node sharing, since using a lambda will create different instances private record ForEachFilteringCriteriaPredicateFunction(EntityDescriptor entityDescriptor, ForEachFilteringCriteria criteria) implements Function, Predicate> { - public Predicate apply(@NonNull ConstraintNodeBuildHelper helper) { + public Predicate apply(ConstraintNodeBuildHelper helper) { return helper.getForEachPredicateForEntityDescriptorAndCriteria(entityDescriptor, criteria); } } - private @NonNull UniConstraintStream forEachForCriteria(@NonNull Class sourceClass, + private UniConstraintStream forEachForCriteria(Class sourceClass, ForEachFilteringCriteria criteria) { + return forEachForCriteria(sourceClass, criteria, RetrievalSemantics.STANDARD); + } + + private UniConstraintStream forEachForCriteria(Class sourceClass, + ForEachFilteringCriteria criteria, RetrievalSemantics retrievalSemantics) { assertValidFromType(sourceClass); var entityDescriptor = solutionDescriptor.findEntityDescriptor(sourceClass); if (entityDescriptor == null || criteria == ForEachFilteringCriteria.ALL) { // Not genuine or shadow entity, or filtering was not requested; no need for filtering. 
- return share(new BavetForEachUniConstraintStream<>(this, sourceClass, null, RetrievalSemantics.STANDARD)); + return share(new BavetForEachUniConstraintStream<>(this, sourceClass, null, retrievalSemantics)); } var listVariableDescriptor = solutionDescriptor.getListVariableDescriptor(); if (listVariableDescriptor == null || !listVariableDescriptor.acceptsValueType(sourceClass)) { // No applicable list variable; don't need to check inverse relationships. return share(new BavetForEachUniConstraintStream<>(this, sourceClass, new ForEachFilteringCriteriaPredicateFunction<>(entityDescriptor, criteria), - RetrievalSemantics.STANDARD)); + retrievalSemantics)); } var entityClass = listVariableDescriptor.getEntityDescriptor().getEntityClass(); if (entityClass == sourceClass) { @@ -138,35 +159,39 @@ public Predicate apply(@NonNull ConstraintNodeBuildHelper helpe } else { // We have the inverse relation variable, so we can read its value directly. return share(new BavetForEachUniConstraintStream<>(this, sourceClass, new ForEachFilteringCriteriaPredicateFunction<>(entityDescriptor, criteria), - RetrievalSemantics.STANDARD)); + retrievalSemantics)); } } @Override - public @NonNull UniConstraintStream forEach(@NonNull Class sourceClass) { + public UniConstraintStream forEach(Class sourceClass) { return forEachForCriteria(sourceClass, ForEachFilteringCriteria.ASSIGNED_AND_CONSISTENT); } @Override - public @NonNull UniConstraintStream forEachIncludingUnassigned(@NonNull Class sourceClass) { + public UniConstraintStream forEachIncludingUnassigned(Class sourceClass) { return forEachForCriteria(sourceClass, ForEachFilteringCriteria.CONSISTENT); } @Override - public @NonNull UniConstraintStream forEachUnfiltered(@NonNull Class sourceClass) { + public UniConstraintStream forEachUnfiltered(Class sourceClass) { return forEachForCriteria(sourceClass, ForEachFilteringCriteria.ALL); } + UniConstraintStream forEachUnfilteredStatic(Class sourceClass) { + return forEachForCriteria(sourceClass, 
ForEachFilteringCriteria.ALL, RetrievalSemantics.PRECOMPUTE); + } + // Required for node sharing, since using a lambda will create different instances private record PredicateSupplier( Predicate suppliedPredicate) implements Function, Predicate> { - public Predicate apply(@NonNull ConstraintNodeBuildHelper helper) { + public Predicate apply(ConstraintNodeBuildHelper helper) { return suppliedPredicate; } } @Override - public @NonNull UniConstraintStream from(@NonNull Class fromClass) { + public UniConstraintStream from(Class fromClass) { assertValidFromType(fromClass); var entityDescriptor = solutionDescriptor.findEntityDescriptor(fromClass); if (entityDescriptor != null && entityDescriptor.isGenuine()) { @@ -180,7 +205,40 @@ public Predicate apply(@NonNull ConstraintNodeBuildHelper helpe } @Override - public @NonNull UniConstraintStream fromUnfiltered(@NonNull Class fromClass) { + @SuppressWarnings("unchecked") + public Stream_ + precompute(Function precomputeSupplier) { + var bavetStream = Objects.requireNonNull(precomputeSupplier.apply(new BavetStaticDataFactory<>(this))); + // TODO: Use switch here in JDK 21 + if (bavetStream instanceof BavetAbstractUniConstraintStream uniStream) { + var out = new BavetPrecomputeUniConstraintStream<>(this, + (BavetAbstractUniConstraintStream) uniStream); + return (Stream_) share(new BavetAftBridgeUniConstraintStream<>(this, out), + out::setAftBridge); + } else if (bavetStream instanceof BavetAbstractBiConstraintStream biStream) { + var out = new BavetPrecomputeBiConstraintStream<>(this, + (BavetAbstractBiConstraintStream) biStream); + return (Stream_) share(new BavetAftBridgeBiConstraintStream<>(this, out), + out::setAftBridge); + } else if (bavetStream instanceof BavetAbstractTriConstraintStream triStream) { + var out = new BavetPrecomputeTriConstraintStream<>(this, + (BavetAbstractTriConstraintStream) triStream); + return (Stream_) share(new BavetAftBridgeTriConstraintStream<>(this, out), + out::setAftBridge); + } else if 
(bavetStream instanceof BavetAbstractQuadConstraintStream quadStream) { + var out = new BavetPrecomputeQuadConstraintStream<>(this, + (BavetAbstractQuadConstraintStream) quadStream); + return (Stream_) share(new BavetAftBridgeQuadConstraintStream<>(this, out), + out::setAftBridge); + } else { + throw new IllegalStateException( + "Impossible state: the supplier (%s) returned a stream (%s) that is not an instance of any Bavet ConstraintStream." + .formatted(precomputeSupplier, bavetStream)); + } + } + + @Override + public UniConstraintStream fromUnfiltered(Class fromClass) { + assertValidFromType(fromClass); + return share(new BavetForEachUniConstraintStream<>(this, fromClass, null, RetrievalSemantics.LEGACY)); + } @@ -199,7 +257,7 @@ public EnvironmentMode getEnvironmentMode() { } @Override - public @NonNull String getDefaultConstraintPackage() { + public String getDefaultConstraintPackage() { return defaultConstraintPackage; } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java index afeb0562f6..c38a493cde 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java @@ -15,6 +15,7 @@ import ai.timefold.solver.core.impl.bavet.NodeNetwork; import ai.timefold.solver.core.impl.bavet.common.AbstractNodeBuildHelper; import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode; import ai.timefold.solver.core.impl.bavet.visual.NodeGraph; import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor; @@ -124,23 +125,33 @@ private static > NodeNetwork buildNodeNe AbstractScoreInliner
scoreInliner, Consumer nodeNetworkVisualizationConsumer) { var buildHelper = new ConstraintNodeBuildHelper<>(consistencyTracker, constraintStreamSet, scoreInliner); - var declaredClassToNodeMap = new LinkedHashMap, List>>(); + var declaredClassToNodeMap = new LinkedHashMap, List>>(); var nodeList = buildHelper.buildNodeList(constraintStreamSet, buildHelper, BavetAbstractConstraintStream::buildNode, node -> { - if (!(node instanceof AbstractForEachUniNode forEachUniNode)) { + if (!(node instanceof BavetRootNode tupleSourceRoot)) { return; } - var forEachClass = forEachUniNode.getForEachClass(); - var forEachUniNodeList = - declaredClassToNodeMap.computeIfAbsent(forEachClass, k -> new ArrayList<>(2)); - if (forEachUniNodeList.size() == 3) { - // Each class can have at most three forEach nodes: one including everything, one including consistent + null vars, the last consistent + no null vars. - throw new IllegalStateException( - "Impossible state: For class (%s) there are already 3 nodes (%s), not adding another (%s)." - .formatted(forEachClass, forEachUniNodeList, forEachUniNode)); + + if (tupleSourceRoot instanceof AbstractForEachUniNode forEachUniNode) { + var forEachClass = forEachUniNode.getForEachClass(); + var forEachUniNodeList = + declaredClassToNodeMap.computeIfAbsent(forEachClass, k -> new ArrayList<>(2)); + if (forEachUniNodeList.stream().filter(sourceNode -> sourceNode instanceof AbstractForEachUniNode) + .count() == 3) { + // Each class can have at most three forEach nodes: one including everything, one including consistent + null vars, the last consistent + no null vars. + throw new IllegalStateException( + "Impossible state: For class (%s) there are already 3 nodes (%s), not adding another (%s)." 
+ .formatted(forEachClass, forEachUniNodeList, forEachUniNode)); + } + forEachUniNodeList.add(forEachUniNode); + } else { + for (var sourceClass : tupleSourceRoot.getSourceClasses()) { + var forEachUniNodeList = + declaredClassToNodeMap.computeIfAbsent(sourceClass, k -> new ArrayList<>(2)); + forEachUniNodeList.add(tupleSourceRoot); + } } - forEachUniNodeList.add(forEachUniNode); }); if (nodeNetworkVisualizationConsumer != null) { var constraintSet = scoreInliner.getConstraints(); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetStaticDataFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetStaticDataFactory.java new file mode 100644 index 0000000000..7d95f28423 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetStaticDataFactory.java @@ -0,0 +1,12 @@ +package ai.timefold.solver.core.impl.score.stream.bavet; + +import ai.timefold.solver.core.api.score.stream.PrecomputeFactory; +import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream; + +public record BavetStaticDataFactory( + BavetConstraintFactory constraintFactory) implements PrecomputeFactory { + @Override + public UniConstraintStream forEachUnfiltered(Class sourceClass) { + return constraintFactory.forEachUnfilteredStatic(sourceClass); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/BavetPrecomputeBiConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/BavetPrecomputeBiConstraintStream.java new file mode 100644 index 0000000000..0c5d3dcd3e --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/BavetPrecomputeBiConstraintStream.java @@ -0,0 +1,72 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.bi; + +import java.util.Objects; +import java.util.Set; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.bi.PrecomputeBiNode; 
+import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.TupleSource; +import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetPrecomputeBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeBiConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics; + +public class BavetPrecomputeBiConstraintStream extends BavetAbstractBiConstraintStream + implements TupleSource { + private final BavetAbstractConstraintStream recordingPrecomputedConstraintStream; + private BavetAftBridgeBiConstraintStream aftStream; + + public BavetPrecomputeBiConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream precomputedConstraintStream) { + super(constraintFactory, RetrievalSemantics.STANDARD); + this.recordingPrecomputedConstraintStream = new BavetRecordingBiConstraintStream<>(constraintFactory, + precomputedConstraintStream); + precomputedConstraintStream.getChildStreamList().add(recordingPrecomputedConstraintStream); + } + + public void setAftBridge(BavetAftBridgeBiConstraintStream aftStream) { + this.aftStream = aftStream; + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + var precomputeBuildHelper = new BavetPrecomputeBuildHelper>(recordingPrecomputedConstraintStream); + var outputStoreSize = buildHelper.extractTupleStoreSize(aftStream); + + buildHelper.addNode(new PrecomputeBiNode<>(precomputeBuildHelper.getNodeNetwork(), + precomputeBuildHelper.getRecordingTupleLifecycle(), + outputStoreSize, + buildHelper.getAggregatedTupleLifecycle(aftStream.getChildStreamList()), + precomputeBuildHelper.getSourceClasses()), + this); 
+ } + + @Override + public void collectActiveConstraintStreams(Set> constraintStreamSet) { + constraintStreamSet.add(this); + } + + // ************************************************************************ + // Equality for node sharing + // ************************************************************************ + + @Override + public int hashCode() { + return Objects.hash(recordingPrecomputedConstraintStream); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (o instanceof BavetPrecomputeBiConstraintStream other) { + return recordingPrecomputedConstraintStream.equals(other.recordingPrecomputedConstraintStream); + } else { + return false; + } + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/BavetRecordingBiConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/BavetRecordingBiConstraintStream.java new file mode 100644 index 0000000000..b3a43ed664 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/BavetRecordingBiConstraintStream.java @@ -0,0 +1,41 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.bi; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; + +public class BavetRecordingBiConstraintStream extends BavetAbstractBiConstraintStream { + protected BavetRecordingBiConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream parent) { + super(constraintFactory, parent); + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + assertEmptyChildStreamList(); + buildHelper.putInsertUpdateRetract(this, 
TupleLifecycle.recording()); + } + + // ************************************************************************ + // Equality for node sharing + // ************************************************************************ + + @Override + public int hashCode() { + return parent.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (o instanceof BavetRecordingBiConstraintStream other) { + return parent.equals(other.parent); + } else { + return false; + } + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/BavetPrecomputeBuildHelper.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/BavetPrecomputeBuildHelper.java new file mode 100644 index 0000000000..8ef81ea1d1 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/BavetPrecomputeBuildHelper.java @@ -0,0 +1,96 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.common; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; + +import ai.timefold.solver.core.api.score.stream.ConstraintFactory; +import ai.timefold.solver.core.api.score.stream.ConstraintStream; +import ai.timefold.solver.core.api.score.stream.PrecomputeFactory; +import ai.timefold.solver.core.impl.bavet.NodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.AbstractNodeBuildHelper; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; +import ai.timefold.solver.core.impl.bavet.common.tuple.AbstractTuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle; +import ai.timefold.solver.core.impl.domain.variable.declarative.ConsistencyTracker; +import ai.timefold.solver.core.impl.score.buildin.SimpleScoreDefinition; +import 
ai.timefold.solver.core.impl.score.constraint.ConstraintMatchPolicy; +import ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics; +import ai.timefold.solver.core.impl.score.stream.common.inliner.AbstractScoreInliner; + +public final class BavetPrecomputeBuildHelper { + private final NodeNetwork nodeNetwork; + private final RecordingTupleLifecycle recordingTupleLifecycle; + private final Class[] sourceClasses; + + public BavetPrecomputeBuildHelper( + BavetAbstractConstraintStream recordingPrecomputeConstraintStream) { + if (recordingPrecomputeConstraintStream.getRetrievalSemantics() != RetrievalSemantics.PRECOMPUTE) { + throw new IllegalStateException( + "Impossible state: %s is not %s but is instead %s. Maybe you accidentally used a %s from %s instead of %s?" + .formatted(RetrievalSemantics.class.getSimpleName(), RetrievalSemantics.PRECOMPUTE, + recordingPrecomputeConstraintStream.getRetrievalSemantics(), + ConstraintStream.class.getSimpleName(), ConstraintFactory.class.getSimpleName(), + PrecomputeFactory.class.getSimpleName())); + } + + var streamList = new ArrayList>(); + var queue = new ArrayDeque>(); + queue.addLast(recordingPrecomputeConstraintStream); + + while (!queue.isEmpty()) { + var current = queue.pollFirst(); + streamList.add(current); + if (current instanceof BavetConstraintStreamBinaryOperation binaryOperation) { + queue.addLast((BavetAbstractConstraintStream) binaryOperation.getLeftParent()); + queue.addLast((BavetAbstractConstraintStream) binaryOperation.getRightParent()); + } else { + if (current.getParent() != null) { + queue.addLast(current.getParent()); + } + } + } + Collections.reverse(streamList); + var streamSet = new LinkedHashSet<>(streamList); + + var buildHelper = new ConstraintNodeBuildHelper<>(new ConsistencyTracker<>(), streamSet, + AbstractScoreInliner.buildScoreInliner(new SimpleScoreDefinition(), Collections.emptyMap(), + ConstraintMatchPolicy.DISABLED)); + + var declaredClassToNodeMap = new LinkedHashMap, 
List>>(); + var nodeList = buildHelper.buildNodeList(streamSet, buildHelper, + BavetAbstractConstraintStream::buildNode, + node -> { + if (!(node instanceof BavetRootNode sourceRootNode)) { + return; + } + var nodeSourceClasses = sourceRootNode.getSourceClasses(); + for (Class nodeSourceClass : nodeSourceClasses) { + var sourceNodeList = declaredClassToNodeMap.computeIfAbsent(nodeSourceClass, k -> new ArrayList<>(2)); + sourceNodeList.add(sourceRootNode); + } + }); + + this.nodeNetwork = AbstractNodeBuildHelper.buildNodeNetwork(nodeList, declaredClassToNodeMap); + this.recordingTupleLifecycle = + (RecordingTupleLifecycle) buildHelper + .getAggregatedTupleLifecycle(List.of(recordingPrecomputeConstraintStream)); + this.sourceClasses = declaredClassToNodeMap.keySet().toArray(new Class[0]); + } + + public NodeNetwork getNodeNetwork() { + return nodeNetwork; + } + + public RecordingTupleLifecycle getRecordingTupleLifecycle() { + return recordingTupleLifecycle; + } + + public Class[] getSourceClasses() { + return sourceClasses; + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetPrecomputeQuadConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetPrecomputeQuadConstraintStream.java new file mode 100644 index 0000000000..b2e616f79e --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetPrecomputeQuadConstraintStream.java @@ -0,0 +1,73 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.quad; + +import java.util.Objects; +import java.util.Set; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.TupleSource; +import ai.timefold.solver.core.impl.bavet.common.tuple.QuadTuple; +import ai.timefold.solver.core.impl.bavet.quad.PrecomputeQuadNode; +import 
ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetPrecomputeBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeQuadConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics; + +public class BavetPrecomputeQuadConstraintStream + extends BavetAbstractQuadConstraintStream + implements TupleSource { + private final BavetAbstractConstraintStream recordingPrecomputedConstraintStream; + private BavetAftBridgeQuadConstraintStream aftStream; + + public BavetPrecomputeQuadConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream precomputedConstraintStream) { + super(constraintFactory, RetrievalSemantics.STANDARD); + this.recordingPrecomputedConstraintStream = new BavetRecordingQuadConstraintStream<>(constraintFactory, + precomputedConstraintStream); + precomputedConstraintStream.getChildStreamList().add(recordingPrecomputedConstraintStream); + } + + public void setAftBridge(BavetAftBridgeQuadConstraintStream aftStream) { + this.aftStream = aftStream; + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + var precomputeBuildHelper = new BavetPrecomputeBuildHelper>(recordingPrecomputedConstraintStream); + var outputStoreSize = buildHelper.extractTupleStoreSize(aftStream); + + buildHelper.addNode(new PrecomputeQuadNode<>(precomputeBuildHelper.getNodeNetwork(), + precomputeBuildHelper.getRecordingTupleLifecycle(), + outputStoreSize, + buildHelper.getAggregatedTupleLifecycle(aftStream.getChildStreamList()), + precomputeBuildHelper.getSourceClasses()), + this); + } + + @Override + public void collectActiveConstraintStreams(Set> constraintStreamSet) { + constraintStreamSet.add(this); + } + + // 
************************************************************************ + // Equality for node sharing + // ************************************************************************ + + @Override + public int hashCode() { + return Objects.hash(recordingPrecomputedConstraintStream); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (o instanceof BavetPrecomputeQuadConstraintStream other) { + return recordingPrecomputedConstraintStream.equals(other.recordingPrecomputedConstraintStream); + } else { + return false; + } + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetRecordingQuadConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetRecordingQuadConstraintStream.java new file mode 100644 index 0000000000..4073ce641e --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetRecordingQuadConstraintStream.java @@ -0,0 +1,42 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.quad; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; + +public class BavetRecordingQuadConstraintStream + extends BavetAbstractQuadConstraintStream { + protected BavetRecordingQuadConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream parent) { + super(constraintFactory, parent); + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + assertEmptyChildStreamList(); + buildHelper.putInsertUpdateRetract(this, TupleLifecycle.recording()); + } + + // ************************************************************************ + // 
Equality for node sharing + // ************************************************************************ + + @Override + public int hashCode() { + return parent.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (o instanceof BavetRecordingQuadConstraintStream other) { + return parent.equals(other.parent); + } else { + return false; + } + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetPrecomputeTriConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetPrecomputeTriConstraintStream.java new file mode 100644 index 0000000000..aaae2015ed --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetPrecomputeTriConstraintStream.java @@ -0,0 +1,72 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.tri; + +import java.util.Objects; +import java.util.Set; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.TupleSource; +import ai.timefold.solver.core.impl.bavet.common.tuple.TriTuple; +import ai.timefold.solver.core.impl.bavet.tri.PrecomputeTriNode; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetPrecomputeBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeTriConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics; + +public class BavetPrecomputeTriConstraintStream extends BavetAbstractTriConstraintStream + implements TupleSource { + private final BavetAbstractConstraintStream recordingPrecomputedConstraintStream; + private BavetAftBridgeTriConstraintStream aftStream; + + public 
BavetPrecomputeTriConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream precomputedConstraintStream) { + super(constraintFactory, RetrievalSemantics.STANDARD); + this.recordingPrecomputedConstraintStream = new BavetRecordingTriConstraintStream<>(constraintFactory, + precomputedConstraintStream); + precomputedConstraintStream.getChildStreamList().add(recordingPrecomputedConstraintStream); + } + + public void setAftBridge(BavetAftBridgeTriConstraintStream aftStream) { + this.aftStream = aftStream; + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + var precomputedBuildHelper = new BavetPrecomputeBuildHelper>(recordingPrecomputedConstraintStream); + var outputStoreSize = buildHelper.extractTupleStoreSize(aftStream); + + buildHelper.addNode(new PrecomputeTriNode<>(precomputedBuildHelper.getNodeNetwork(), + precomputedBuildHelper.getRecordingTupleLifecycle(), + outputStoreSize, + buildHelper.getAggregatedTupleLifecycle(aftStream.getChildStreamList()), + precomputedBuildHelper.getSourceClasses()), + this); + } + + @Override + public void collectActiveConstraintStreams(Set> constraintStreamSet) { + constraintStreamSet.add(this); + } + + // ************************************************************************ + // Equality for node sharing + // ************************************************************************ + + @Override + public int hashCode() { + return Objects.hash(recordingPrecomputedConstraintStream); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (o instanceof BavetPrecomputeTriConstraintStream other) { + return recordingPrecomputedConstraintStream.equals(other.recordingPrecomputedConstraintStream); + } else { + return false; + } + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetRecordingTriConstraintStream.java 
b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetRecordingTriConstraintStream.java new file mode 100644 index 0000000000..9ec3d45218 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetRecordingTriConstraintStream.java @@ -0,0 +1,42 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.tri; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; + +public class BavetRecordingTriConstraintStream + extends BavetAbstractTriConstraintStream { + protected BavetRecordingTriConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream parent) { + super(constraintFactory, parent); + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + assertEmptyChildStreamList(); + buildHelper.putInsertUpdateRetract(this, TupleLifecycle.recording()); + } + + // ************************************************************************ + // Equality for node sharing + // ************************************************************************ + + @Override + public int hashCode() { + return parent.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (o instanceof BavetRecordingTriConstraintStream other) { + return parent.equals(other.parent); + } else { + return false; + } + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetForEachUniConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetForEachUniConstraintStream.java index f5b8cc1c7e..1253b46eb7 100644 --- 
a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetForEachUniConstraintStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetForEachUniConstraintStream.java @@ -77,12 +77,13 @@ public boolean equals(Object other) { return false; } BavetForEachUniConstraintStream that = (BavetForEachUniConstraintStream) other; - return Objects.equals(forEachClass, that.forEachClass) && Objects.equals(filterFunction, that.filterFunction); + return Objects.equals(forEachClass, that.forEachClass) && Objects.equals(filterFunction, that.filterFunction) + && getRetrievalSemantics().equals(that.getRetrievalSemantics()); } @Override public int hashCode() { - return Objects.hash(forEachClass, filterFunction); + return Objects.hash(forEachClass, filterFunction, getRetrievalSemantics()); } @Override diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetPrecomputeUniConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetPrecomputeUniConstraintStream.java new file mode 100644 index 0000000000..f9ce270054 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetPrecomputeUniConstraintStream.java @@ -0,0 +1,72 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.uni; + +import java.util.Objects; +import java.util.Set; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.TupleSource; +import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; +import ai.timefold.solver.core.impl.bavet.uni.PrecomputeUniNode; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetPrecomputeBuildHelper; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; +import 
ai.timefold.solver.core.impl.score.stream.bavet.common.bridge.BavetAftBridgeUniConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics; + +public class BavetPrecomputeUniConstraintStream extends BavetAbstractUniConstraintStream + implements TupleSource { + private final BavetAbstractConstraintStream recordingPrecomputedConstraintStream; + private BavetAftBridgeUniConstraintStream aftStream; + + public BavetPrecomputeUniConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream precomputedConstraintStream) { + super(constraintFactory, RetrievalSemantics.STANDARD); + this.recordingPrecomputedConstraintStream = new BavetRecordingUniConstraintStream<>(constraintFactory, + precomputedConstraintStream); + precomputedConstraintStream.getChildStreamList().add(recordingPrecomputedConstraintStream); + } + + public void setAftBridge(BavetAftBridgeUniConstraintStream aftStream) { + this.aftStream = aftStream; + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + var precomputeBuildHelper = new BavetPrecomputeBuildHelper>(recordingPrecomputedConstraintStream); + var outputStoreSize = buildHelper.extractTupleStoreSize(aftStream); + + buildHelper.addNode(new PrecomputeUniNode<>(precomputeBuildHelper.getNodeNetwork(), + precomputeBuildHelper.getRecordingTupleLifecycle(), + outputStoreSize, + buildHelper.getAggregatedTupleLifecycle(aftStream.getChildStreamList()), + precomputeBuildHelper.getSourceClasses()), + this); + } + + @Override + public void collectActiveConstraintStreams(Set> constraintStreamSet) { + constraintStreamSet.add(this); + } + + // ************************************************************************ + // Equality for node sharing + // ************************************************************************ + + @Override + public int hashCode() { + return Objects.hash(recordingPrecomputedConstraintStream); + } + + @Override + public boolean equals(Object o) 
{ + if (this == o) { + return true; + } else if (o instanceof BavetPrecomputeUniConstraintStream other) { + return recordingPrecomputedConstraintStream.equals(other.recordingPrecomputedConstraintStream); + } else { + return false; + } + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetRecordingUniConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetRecordingUniConstraintStream.java new file mode 100644 index 0000000000..fb78d6fc41 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetRecordingUniConstraintStream.java @@ -0,0 +1,41 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.uni; + +import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory; +import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper; + +public class BavetRecordingUniConstraintStream extends BavetAbstractUniConstraintStream { + protected BavetRecordingUniConstraintStream( + BavetConstraintFactory constraintFactory, + BavetAbstractConstraintStream parent) { + super(constraintFactory, parent); + } + + @Override + public > void buildNode(ConstraintNodeBuildHelper buildHelper) { + assertEmptyChildStreamList(); + buildHelper.putInsertUpdateRetract(this, TupleLifecycle.recording()); + } + + // ************************************************************************ + // Equality for node sharing + // ************************************************************************ + + @Override + public int hashCode() { + return parent.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (o instanceof BavetRecordingUniConstraintStream other) { + return 
parent.equals(other.parent); + } else { + return false; + } + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/RetrievalSemantics.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/RetrievalSemantics.java index fa19cc252c..00b2b1ad55 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/RetrievalSemantics.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/RetrievalSemantics.java @@ -2,12 +2,13 @@ import ai.timefold.solver.core.api.domain.variable.PlanningVariable; import ai.timefold.solver.core.api.score.stream.ConstraintFactory; +import ai.timefold.solver.core.api.score.stream.PrecomputeFactory; /** * Determines the behavior of joins and conditional propagation * based on whether they are coming off of a constraint stream started by - * either {@link ConstraintFactory#from(Class)} - * or {@link ConstraintFactory#forEach(Class)} + * either {@link ConstraintFactory#from(Class)}, {@link ConstraintFactory#forEach(Class)}, + * or {@link PrecomputeFactory#forEachUnfiltered(Class)} * family of methods. * *

@@ -29,6 +30,15 @@ public enum RetrievalSemantics { * Applies when the stream comes off of a {@link ConstraintFactory#forEach(Class)} family of methods. */ STANDARD, + + /** + * Joins and conditional propagation always include entities with null planning variables, + * regardless of whether their planning variables allow unassigned values. + *

+ * Applies when the stream comes off of a {@link PrecomputeFactory#forEachUnfiltered(Class)} family of methods. + */ + PRECOMPUTE, + /** * Joins include entities with null planning variables if these variables allow unassigned values. * Conditional propagation always includes entities with null planning variables, diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/bi/InnerBiConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/bi/InnerBiConstraintStream.java index e3125ff6d3..e4386aac19 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/bi/InnerBiConstraintStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/bi/InnerBiConstraintStream.java @@ -1,7 +1,5 @@ package ai.timefold.solver.core.impl.score.stream.common.bi; -import static ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics.STANDARD; - import java.math.BigDecimal; import java.util.Arrays; import java.util.Collection; @@ -45,53 +43,55 @@ static BiFunction> createDefaultIndictedObjectsMappin @Override default @NonNull TriConstraintStream join(@NonNull Class otherClass, TriJoiner... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return join(getConstraintFactory().forEach(otherClass), joiners); - } else { - return join(getConstraintFactory().from(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> join(getConstraintFactory().forEach(otherClass), joiners); + case PRECOMPUTE -> join(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + case LEGACY -> join(getConstraintFactory().from(otherClass), joiners); + }; } @Override default @NonNull BiConstraintStream ifExists(@NonNull Class otherClass, TriJoiner @NonNull... 
joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEach(otherClass), joiners); + case PRECOMPUTE -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull BiConstraintStream ifExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull TriJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case PRECOMPUTE -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull BiConstraintStream ifNotExists(@NonNull Class otherClass, @NonNull TriJoiner... 
joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifNotExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEach(otherClass), joiners); + case PRECOMPUTE -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull BiConstraintStream ifNotExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull TriJoiner... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case PRECOMPUTE -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/quad/InnerQuadConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/quad/InnerQuadConstraintStream.java index 51911c1737..99c89e14cf 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/quad/InnerQuadConstraintStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/quad/InnerQuadConstraintStream.java @@ -44,43 +44,45 @@ static QuadFunction> createDefaultIndicte @Override default @NonNull QuadConstraintStream 
ifExists(@NonNull Class otherClass, @NonNull PentaJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == RetrievalSemantics.STANDARD) { - return ifExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEach(otherClass), joiners); + case PRECOMPUTE -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull QuadConstraintStream ifExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull PentaJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == RetrievalSemantics.STANDARD) { - return ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case PRECOMPUTE -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull QuadConstraintStream ifNotExists(@NonNull Class otherClass, @NonNull PentaJoiner @NonNull... 
joiners) { - if (getRetrievalSemantics() == RetrievalSemantics.STANDARD) { - return ifNotExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEach(otherClass), joiners); + case PRECOMPUTE -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull QuadConstraintStream ifNotExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull PentaJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == RetrievalSemantics.STANDARD) { - return ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case PRECOMPUTE -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/tri/InnerTriConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/tri/InnerTriConstraintStream.java index 6d975c6cf5..20ab33db45 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/tri/InnerTriConstraintStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/tri/InnerTriConstraintStream.java @@ -1,7 +1,5 @@ package ai.timefold.solver.core.impl.score.stream.common.tri; 
-import static ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics.STANDARD; - import java.math.BigDecimal; import java.util.Arrays; import java.util.Collection; @@ -46,53 +44,55 @@ static TriFunction> createDefaultIndictedObject @Override default @NonNull QuadConstraintStream join(@NonNull Class otherClass, @NonNull QuadJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return join(getConstraintFactory().forEach(otherClass), joiners); - } else { - return join(getConstraintFactory().from(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> join(getConstraintFactory().forEach(otherClass), joiners); + case PRECOMPUTE -> join(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + case LEGACY -> join(getConstraintFactory().from(otherClass), joiners); + }; } @Override default @NonNull TriConstraintStream ifExists(@NonNull Class otherClass, @NonNull QuadJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEach(otherClass), joiners); + case PRECOMPUTE -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull TriConstraintStream ifExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull QuadJoiner @NonNull... 
joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case PRECOMPUTE -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull TriConstraintStream ifNotExists(@NonNull Class otherClass, @NonNull QuadJoiner @NonNull... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifNotExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEach(otherClass), joiners); + case PRECOMPUTE -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull TriConstraintStream ifNotExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull QuadJoiner @NonNull... 
joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case PRECOMPUTE -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/uni/InnerUniConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/uni/InnerUniConstraintStream.java index 65efad8fe9..dc7b62e7b0 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/uni/InnerUniConstraintStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/common/uni/InnerUniConstraintStream.java @@ -1,7 +1,5 @@ package ai.timefold.solver.core.impl.score.stream.common.uni; -import static ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics.STANDARD; - import java.math.BigDecimal; import java.util.Collection; import java.util.Collections; @@ -47,11 +45,11 @@ static Function> createDefaultIndictedObjectsMapping() { @Override default @NonNull BiConstraintStream join(@NonNull Class otherClass, @NonNull BiJoiner... 
joiners) { - if (getRetrievalSemantics() == STANDARD) { - return join(getConstraintFactory().forEach(otherClass), joiners); - } else { - return join(getConstraintFactory().from(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> join(getConstraintFactory().forEach(otherClass), joiners); + case PRECOMPUTE -> join(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + case LEGACY -> join(getConstraintFactory().from(otherClass), joiners); + }; } /** @@ -66,42 +64,44 @@ static Function> createDefaultIndictedObjectsMapping() { @Override default @NonNull UniConstraintStream ifExists(@NonNull Class otherClass, @NonNull BiJoiner... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEach(otherClass), joiners); + case PRECOMPUTE -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull UniConstraintStream ifExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull BiJoiner... 
joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case PRECOMPUTE -> ifExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull UniConstraintStream ifNotExists(@NonNull Class otherClass, @NonNull BiJoiner... joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifNotExists(getConstraintFactory().forEach(otherClass), joiners); - } else { + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEach(otherClass), joiners); + case PRECOMPUTE -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); // Calls fromUnfiltered() for backward compatibility only - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override default @NonNull UniConstraintStream ifNotExistsIncludingUnassigned(@NonNull Class otherClass, @NonNull BiJoiner... 
joiners) { - if (getRetrievalSemantics() == STANDARD) { - return ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); - } else { - return ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); - } + return switch (getRetrievalSemantics()) { + case STANDARD -> ifNotExists(getConstraintFactory().forEachIncludingUnassigned(otherClass), joiners); + case PRECOMPUTE -> ifNotExists(getConstraintFactory().forEachUnfiltered(otherClass), joiners); + // Calls fromUnfiltered() for backward compatibility only + case LEGACY -> ifNotExists(getConstraintFactory().fromUnfiltered(otherClass), joiners); + }; } @Override diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/BavetBiConstraintStreamPrecomputeTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/BavetBiConstraintStreamPrecomputeTest.java new file mode 100644 index 0000000000..71ea87ccd3 --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/bavet/bi/BavetBiConstraintStreamPrecomputeTest.java @@ -0,0 +1,13 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.bi; + +import ai.timefold.solver.core.impl.score.constraint.ConstraintMatchPolicy; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintStreamImplSupport; +import ai.timefold.solver.core.impl.score.stream.common.bi.AbstractBiConstraintStreamPrecomputeTest; + +final class BavetBiConstraintStreamPrecomputeTest extends AbstractBiConstraintStreamPrecomputeTest { + + public BavetBiConstraintStreamPrecomputeTest(ConstraintMatchPolicy constraintMatchPolicy) { + super(new BavetConstraintStreamImplSupport(constraintMatchPolicy)); + } + +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetQuadConstraintStreamPrecomputeTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetQuadConstraintStreamPrecomputeTest.java new file mode 100644 index 
0000000000..d0996c263d --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/bavet/quad/BavetQuadConstraintStreamPrecomputeTest.java @@ -0,0 +1,13 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.quad; + +import ai.timefold.solver.core.impl.score.constraint.ConstraintMatchPolicy; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintStreamImplSupport; +import ai.timefold.solver.core.impl.score.stream.common.quad.AbstractQuadConstraintStreamPrecomputeTest; + +final class BavetQuadConstraintStreamPrecomputeTest extends AbstractQuadConstraintStreamPrecomputeTest { + + public BavetQuadConstraintStreamPrecomputeTest(ConstraintMatchPolicy constraintMatchPolicy) { + super(new BavetConstraintStreamImplSupport(constraintMatchPolicy)); + } + +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetTriConstraintStreamPrecomputeTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetTriConstraintStreamPrecomputeTest.java new file mode 100644 index 0000000000..131cb2b68c --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/bavet/tri/BavetTriConstraintStreamPrecomputeTest.java @@ -0,0 +1,13 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.tri; + +import ai.timefold.solver.core.impl.score.constraint.ConstraintMatchPolicy; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintStreamImplSupport; +import ai.timefold.solver.core.impl.score.stream.common.tri.AbstractTriConstraintStreamPrecomputeTest; + +final class BavetTriConstraintStreamPrecomputeTest extends AbstractTriConstraintStreamPrecomputeTest { + + public BavetTriConstraintStreamPrecomputeTest(ConstraintMatchPolicy constraintMatchPolicy) { + super(new BavetConstraintStreamImplSupport(constraintMatchPolicy)); + } + +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetUniConstraintStreamPrecomputeTest.java 
b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetUniConstraintStreamPrecomputeTest.java new file mode 100644 index 0000000000..f96fffaf63 --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetUniConstraintStreamPrecomputeTest.java @@ -0,0 +1,13 @@ +package ai.timefold.solver.core.impl.score.stream.bavet.uni; + +import ai.timefold.solver.core.impl.score.constraint.ConstraintMatchPolicy; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintStreamImplSupport; +import ai.timefold.solver.core.impl.score.stream.common.uni.AbstractUniConstraintStreamPrecomputeTest; + +final class BavetUniConstraintStreamPrecomputeTest extends AbstractUniConstraintStreamPrecomputeTest { + + public BavetUniConstraintStreamPrecomputeTest(ConstraintMatchPolicy constraintMatchPolicy) { + super(new BavetConstraintStreamImplSupport(constraintMatchPolicy)); + } + +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/ConstraintStreamNodeSharingTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/ConstraintStreamNodeSharingTest.java index 501d247b4e..d7ddbfce91 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/ConstraintStreamNodeSharingTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/ConstraintStreamNodeSharingTest.java @@ -210,4 +210,8 @@ default void sameParentSameFunctionExpand() { void differentSecondSourceConcat(); void sameSourcesConcat(); + + void sameDataPrecompute(); + + void differentDataPrecompute(); } diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/ConstraintStreamPrecomputeTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/ConstraintStreamPrecomputeTest.java new file mode 100644 index 0000000000..0f5471e922 --- /dev/null +++ 
b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/ConstraintStreamPrecomputeTest.java @@ -0,0 +1,33 @@ +package ai.timefold.solver.core.impl.score.stream.common; + +public interface ConstraintStreamPrecomputeTest { + void filter_0_changed(); + + default void filter_1_changed() { + // requires two elements, so Bi, Tri and Quad + } + + default void filter_2_changed() { + // requires three elements, so Tri and Quad + } + + default void filter_3_changed() { + // requires four elements, Quad + } + + void ifExists(); + + void ifNotExists(); + + void groupBy(); + + void flattenLast(); + + void map(); + + void concat(); + + void distinct(); + + void complement(); +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamNodeSharingTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamNodeSharingTest.java index ea84df7d9c..a1c1fb2e91 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamNodeSharingTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamNodeSharingTest.java @@ -21,6 +21,7 @@ import ai.timefold.solver.core.testdomain.TestdataSolution; import ai.timefold.solver.core.testdomain.TestdataValue; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.TestTemplate; @@ -544,4 +545,34 @@ public void sameSourcesConcat() { .concat(baseStream.filter(filter1))) .isSameAs(baseStream.concat(baseStream.filter(filter1))); } + + @Override + @TestTemplate + public void sameDataPrecompute() { + BiPredicate filter1 = (a, b) -> true; + Assertions.assertThat((BiConstraintStream) constraintFactory.precompute( + precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .join(TestdataEntity.class) + .filter(filter1))) + .isSameAs(constraintFactory.precompute( + 
precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .join(TestdataEntity.class) + .filter(filter1))); + } + + @Override + @TestTemplate + public void differentDataPrecompute() { + BiPredicate filter1 = (a, b) -> true; + BiPredicate filter2 = (a, b) -> false; + + Assertions.assertThat((BiConstraintStream) constraintFactory.precompute( + precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .join(TestdataEntity.class) + .filter(filter1))) + .isNotSameAs(constraintFactory.precompute( + precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .join(TestdataEntity.class) + .filter(filter2))); + } } diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamPrecomputeTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamPrecomputeTest.java new file mode 100644 index 0000000000..494b37aa18 --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamPrecomputeTest.java @@ -0,0 +1,395 @@ +package ai.timefold.solver.core.impl.score.stream.common.bi; + +import java.util.List; +import java.util.function.Function; + +import ai.timefold.solver.core.api.score.buildin.simple.SimpleScore; +import ai.timefold.solver.core.api.score.stream.ConstraintCollectors; +import ai.timefold.solver.core.api.score.stream.Joiners; +import ai.timefold.solver.core.api.score.stream.PrecomputeFactory; +import ai.timefold.solver.core.api.score.stream.bi.BiConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStreamTest; +import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamImplSupport; +import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamPrecomputeTest; +import ai.timefold.solver.core.impl.util.Pair; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishEntity; +import 
ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishEntityGroup; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishSolution; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishValue; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishValueGroup; + +import org.junit.jupiter.api.TestTemplate; +import org.mockito.Mockito; + +public abstract class AbstractBiConstraintStreamPrecomputeTest extends AbstractConstraintStreamTest + implements ConstraintStreamPrecomputeTest { + protected AbstractBiConstraintStreamPrecomputeTest(ConstraintStreamImplSupport implSupport) { + super(implSupport); + } + + @Override + @TestTemplate + public void filter_0_changed() { + var solution = TestdataLavishSolution.generateSolution(); + var entityGroup = new TestdataLavishEntityGroup("MyEntityGroup"); + var valueGroup = new TestdataLavishValueGroup("MyValueGroup"); + solution.getEntityGroupList().add(entityGroup); + solution.getValueGroupList().add(valueGroup); + + var value1 = Mockito.spy(new TestdataLavishValue("MyValue 1", valueGroup)); + solution.getValueList().add(value1); + var value2 = Mockito.spy(new TestdataLavishValue("MyValue 2", valueGroup)); + solution.getValueList().add(value2); + var value3 = Mockito.spy(new TestdataLavishValue("MyValue 3", null)); + solution.getValueList().add(value3); + + var entity1 = Mockito.spy(new TestdataLavishEntity("MyEntity 1", entityGroup, value1)); + solution.getEntityList().add(entity1); + var entity2 = new TestdataLavishEntity("MyEntity 2", entityGroup, value1); + solution.getEntityList().add(entity2); + var entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value1); + solution.getEntityList().add(entity3); + + var scoreDirector = + buildScoreDirector(factory -> factory.precompute(data -> data.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .filter((entity, value) -> entity.getEntityGroup() == entityGroup + && 
value.getValueGroup() == valueGroup)) + .filter((entity, value) -> entity.getValue() == value1) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + Mockito.reset(entity1); + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1, value1), + assertMatch(entity1, value2), + assertMatch(entity2, value1), + assertMatch(entity2, value2)); + Mockito.verify(entity1, Mockito.atLeastOnce()).getEntityGroup(); + + // Incrementally update a variable + Mockito.reset(entity1); + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(solution.getFirstValue()); + scoreDirector.afterVariableChanged(entity1, "value"); + assertScore(scoreDirector, + assertMatch(entity2, value1), + assertMatch(entity2, value2)); + Mockito.verify(entity1, Mockito.never()).getEntityGroup(); + + // Incrementally update a fact + scoreDirector.beforeProblemPropertyChanged(entity3); + entity3.setEntityGroup(entityGroup); + scoreDirector.afterProblemPropertyChanged(entity3); + assertScore(scoreDirector, + assertMatch(entity2, value1), + assertMatch(entity2, value2), + assertMatch(entity3, value1), + assertMatch(entity3, value2)); + + // Remove entity + scoreDirector.beforeEntityRemoved(entity3); + solution.getEntityList().remove(entity3); + scoreDirector.afterEntityRemoved(entity3); + assertScore(scoreDirector, + assertMatch(entity2, value1), + assertMatch(entity2, value2)); + + // Add it back again, to make sure it was properly removed before + scoreDirector.beforeEntityAdded(entity3); + solution.getEntityList().add(entity3); + scoreDirector.afterEntityAdded(entity3); + assertScore(scoreDirector, + assertMatch(entity2, value1), + assertMatch(entity2, value2), + assertMatch(entity3, value1), + assertMatch(entity3, value2)); + } + + @Override + @TestTemplate + public void filter_1_changed() { + var solution = TestdataLavishSolution.generateSolution(); + var entityGroup = new 
TestdataLavishEntityGroup("MyEntityGroup"); + var valueGroup = new TestdataLavishValueGroup("MyValueGroup"); + solution.getEntityGroupList().add(entityGroup); + solution.getValueGroupList().add(valueGroup); + + var value1 = Mockito.spy(new TestdataLavishValue("MyValue 1", valueGroup)); + solution.getValueList().add(value1); + var value2 = Mockito.spy(new TestdataLavishValue("MyValue 2", valueGroup)); + solution.getValueList().add(value2); + var value3 = Mockito.spy(new TestdataLavishValue("MyValue 3", null)); + solution.getValueList().add(value3); + + var entity1 = Mockito.spy(new TestdataLavishEntity("MyEntity 1", entityGroup, value1)); + solution.getEntityList().add(entity1); + var entity2 = new TestdataLavishEntity("MyEntity 2", entityGroup, value1); + solution.getEntityList().add(entity2); + var entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value1); + solution.getEntityList().add(entity3); + + var scoreDirector = + buildScoreDirector(factory -> factory.precompute(data -> data.forEachUnfiltered(TestdataLavishValue.class) + .join(TestdataLavishEntity.class) + .filter((value, entity) -> entity.getEntityGroup() == entityGroup + && value.getValueGroup() == valueGroup)) + .filter((value, entity) -> entity.getValue() == value1) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + Mockito.reset(entity1); + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(value1, entity1), + assertMatch(value2, entity1), + assertMatch(value1, entity2), + assertMatch(value2, entity2)); + Mockito.verify(entity1, Mockito.atLeastOnce()).getEntityGroup(); + + // Incrementally update a variable + Mockito.reset(entity1); + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(solution.getFirstValue()); + scoreDirector.afterVariableChanged(entity1, "value"); + assertScore(scoreDirector, + assertMatch(value1, entity2), + assertMatch(value2, entity2)); + 
Mockito.verify(entity1, Mockito.never()).getEntityGroup(); + + // Incrementally update a fact + scoreDirector.beforeProblemPropertyChanged(entity3); + entity3.setEntityGroup(entityGroup); + scoreDirector.afterProblemPropertyChanged(entity3); + assertScore(scoreDirector, + assertMatch(value1, entity2), + assertMatch(value2, entity2), + assertMatch(value1, entity3), + assertMatch(value2, entity3)); + + // Remove entity + scoreDirector.beforeEntityRemoved(entity3); + solution.getEntityList().remove(entity3); + scoreDirector.afterEntityRemoved(entity3); + assertScore(scoreDirector, + assertMatch(value1, entity2), + assertMatch(value2, entity2)); + + // Add it back again, to make sure it was properly removed before + scoreDirector.beforeEntityAdded(entity3); + solution.getEntityList().add(entity3); + scoreDirector.afterEntityAdded(entity3); + assertScore(scoreDirector, + assertMatch(value1, entity2), + assertMatch(value2, entity2), + assertMatch(value1, entity3), + assertMatch(value2, entity3)); + } + + private void assertPrecompute(TestdataLavishSolution solution, + List> expectedValues, + Function> entityStreamSupplier) { + var scoreDirector = + buildScoreDirector(factory -> factory.precompute(entityStreamSupplier) + .ifExists(TestdataLavishEntity.class) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector); + + for (var entity : solution.getEntityList()) { + scoreDirector.beforeVariableChanged(entity, "value"); + entity.setValue(solution.getFirstValue()); + scoreDirector.afterVariableChanged(entity, "value"); + } + + assertScore(scoreDirector, expectedValues.stream() + .map(pair -> new Object[] { pair.key(), pair.value() }) + .map(AbstractConstraintStreamTest::assertMatch) + .toArray(AssertableMatch[]::new)); + } + + @Override + @TestTemplate + public void ifExists() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = 
new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Pair<>(entityWithGroup, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .ifExists(TestdataLavishEntityGroup.class, Joiners.equal( + (a, b) -> a.getEntityGroup(), Function.identity()))); + } + + @Override + @TestTemplate + public void ifNotExists() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Pair<>(entityWithoutGroup, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .ifNotExists(TestdataLavishEntityGroup.class, Joiners.equal( + (a, b) -> a.getEntityGroup(), Function.identity()))); + } + + @Override + @TestTemplate + public void groupBy() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + + var value = new 
TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Pair<>(entityGroup, 1)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .filter(entity -> entity.getEntityGroup() != null) + .groupBy(TestdataLavishEntity::getEntityGroup, ConstraintCollectors.count())); + } + + @Override + @TestTemplate + public void flattenLast() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Pair<>(entityWithoutGroup, value), + new Pair<>(entityWithGroup, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .groupBy(ConstraintCollectors.toList()) + .flattenLast(entityList -> entityList) + .join(TestdataLavishValue.class)); + } + + @Override + @TestTemplate + public void map() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup1 = new TestdataLavishEntity(); + var entityWithGroup2 = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup1.setEntityGroup(entityGroup); + entityWithGroup2.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Pair<>(entityGroup, value), + new Pair<>(entityGroup, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + 
.join(TestdataLavishValue.class) + .filter((entity, joinedValue) -> entity.getEntityGroup() != null) + .map((entity, joinedValue) -> entity.getEntityGroup(), + (entity, joinedValue) -> joinedValue)); + } + + @Override + @TestTemplate + public void concat() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Pair<>(entityWithoutGroup, value), new Pair<>(entityWithGroup, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .filter((entity, joinedValue) -> entity.getEntityGroup() == null) + .concat(pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .filter((entity, joinedValue) -> entity.getEntityGroup() != null))); + } + + @Override + @TestTemplate + public void distinct() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup1 = new TestdataLavishEntity(); + var entityWithGroup2 = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup1.setEntityGroup(entityGroup); + entityWithGroup2.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Pair<>(entityGroup, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + 
.join(TestdataLavishValue.class) + .filter((entity, joinedValue) -> entity.getEntityGroup() != null) + .map((entity, joinedValue) -> entity.getEntityGroup(), + (entity, joinedValue) -> joinedValue) + .distinct()); + } + + @Override + @TestTemplate + public void complement() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup1 = new TestdataLavishEntity(); + var entityWithGroup2 = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup1.setEntityGroup(entityGroup); + entityWithGroup2.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of( + new Pair<>(entityWithGroup1, value), + new Pair<>(entityWithGroup2, value), + new Pair<>(entityWithoutGroup, null)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .filter((entity, joinedValue) -> entity.getEntityGroup() != null) + .complement(TestdataLavishEntity.class)); + } +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamTest.java index 0549b286ce..c7b4dfef1f 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/bi/AbstractBiConstraintStreamTest.java @@ -38,6 +38,7 @@ import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStreamTest; import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamFunctionalTest; import 
ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamImplSupport; +import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamTestExtension; import ai.timefold.solver.core.testdomain.TestdataEntity; import ai.timefold.solver.core.testdomain.list.unassignedvar.TestdataAllowsUnassignedValuesListEntity; import ai.timefold.solver.core.testdomain.list.unassignedvar.TestdataAllowsUnassignedValuesListSolution; @@ -52,7 +53,12 @@ import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishValueGroup; import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.parallel.Execution; +import org.junit.jupiter.api.parallel.ExecutionMode; +@ExtendWith(ConstraintStreamTestExtension.class) +@Execution(ExecutionMode.CONCURRENT) public abstract class AbstractBiConstraintStreamTest extends AbstractConstraintStreamTest implements ConstraintStreamFunctionalTest { @@ -3305,5 +3311,4 @@ public void joinerEqualsAndSameness() { assertMatch(entity3, entity1), assertMatch(entity3, entity2)); } - } diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/quad/AbstractQuadConstraintStreamNodeSharingTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/quad/AbstractQuadConstraintStreamNodeSharingTest.java index eaec6ff672..6ba46af78d 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/quad/AbstractQuadConstraintStreamNodeSharingTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/quad/AbstractQuadConstraintStreamNodeSharingTest.java @@ -19,6 +19,7 @@ import ai.timefold.solver.core.testdomain.TestdataEntity; import ai.timefold.solver.core.testdomain.TestdataSolution; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.TestTemplate; @@ -476,4 +477,42 @@ public void sameSourcesConcat() { .concat(baseStream.filter(filter1))) 
.isSameAs(baseStream.concat(baseStream.filter(filter1))); } + + @Override + @TestTemplate + public void sameDataPrecompute() { + QuadPredicate filter1 = (a, b, c, d) -> true; + Assertions.assertThat((QuadConstraintStream) constraintFactory.precompute( + precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .join(TestdataEntity.class) + .join(TestdataEntity.class) + .join(TestdataEntity.class) + .filter(filter1))) + .isSameAs(constraintFactory.precompute( + precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .join(TestdataEntity.class) + .join(TestdataEntity.class) + .join(TestdataEntity.class) + .filter(filter1))); + } + + @Override + @TestTemplate + public void differentDataPrecompute() { + QuadPredicate filter1 = (a, b, c, d) -> true; + QuadPredicate filter2 = (a, b, c, d) -> false; + + Assertions.assertThat((QuadConstraintStream) constraintFactory.precompute( + precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .join(TestdataEntity.class) + .join(TestdataEntity.class) + .join(TestdataEntity.class) + .filter(filter1))) + .isNotSameAs(constraintFactory.precompute( + precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .join(TestdataEntity.class) + .join(TestdataEntity.class) + .join(TestdataEntity.class) + .filter(filter2))); + } } diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/quad/AbstractQuadConstraintStreamPrecomputeTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/quad/AbstractQuadConstraintStreamPrecomputeTest.java new file mode 100644 index 0000000000..c8a3a5766a --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/quad/AbstractQuadConstraintStreamPrecomputeTest.java @@ -0,0 +1,416 @@ +package ai.timefold.solver.core.impl.score.stream.common.quad; + +import java.util.List; +import java.util.function.Function; + +import 
ai.timefold.solver.core.api.function.QuadFunction; +import ai.timefold.solver.core.api.function.TriFunction; +import ai.timefold.solver.core.api.score.buildin.simple.SimpleScore; +import ai.timefold.solver.core.api.score.stream.ConstraintCollectors; +import ai.timefold.solver.core.api.score.stream.Joiners; +import ai.timefold.solver.core.api.score.stream.PrecomputeFactory; +import ai.timefold.solver.core.api.score.stream.quad.QuadConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStreamTest; +import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamImplSupport; +import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamPrecomputeTest; +import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamTestExtension; +import ai.timefold.solver.core.impl.util.Quadruple; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishEntity; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishEntityGroup; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishSolution; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishValue; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishValueGroup; + +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.parallel.Execution; +import org.junit.jupiter.api.parallel.ExecutionMode; +import org.mockito.Mockito; + +@ExtendWith(ConstraintStreamTestExtension.class) +@Execution(ExecutionMode.CONCURRENT) +public abstract class AbstractQuadConstraintStreamPrecomputeTest extends AbstractConstraintStreamTest + implements ConstraintStreamPrecomputeTest { + protected AbstractQuadConstraintStreamPrecomputeTest(ConstraintStreamImplSupport implSupport) { + super(implSupport); + } + + private void assertPrecomputeFilterChanged( + TriFunction> precomputeStream, + QuadFunction entityPicker, + QuadFunction> inputDataToTuple) { + 
var solution = TestdataLavishSolution.generateSolution(); + var entityGroup = new TestdataLavishEntityGroup("MyEntityGroup"); + var valueGroup = new TestdataLavishValueGroup("MyValueGroup"); + solution.getEntityGroupList().add(entityGroup); + solution.getValueGroupList().add(valueGroup); + + var value1 = Mockito.spy(new TestdataLavishValue("MyValue 1", valueGroup)); + solution.getValueList().add(value1); + var value2 = Mockito.spy(new TestdataLavishValue("MyValue 2", valueGroup)); + solution.getValueList().add(value2); + var value3 = Mockito.spy(new TestdataLavishValue("MyValue 3", null)); + solution.getValueList().add(value3); + + var entity1 = Mockito.spy(new TestdataLavishEntity("MyEntity 1", entityGroup, value1)); + solution.getEntityList().add(entity1); + var entity2 = new TestdataLavishEntity("MyEntity 2", entityGroup, value1); + solution.getEntityList().add(entity2); + var entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value1); + solution.getEntityList().add(entity3); + + var scoreDirector = + buildScoreDirector(factory -> factory.precompute(pf -> precomputeStream.apply(pf, entityGroup, valueGroup)) + .filter((a, b, c, d) -> entityPicker.apply(a, b, c, d).getValue() == value1) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + var createMatch = + (QuadFunction) ( + entity, + value, matchEntityGroup, matchValueGroup) -> { + var tuple = inputDataToTuple.apply(entity, value, matchEntityGroup, matchValueGroup); + return assertMatch(tuple.a(), tuple.b(), tuple.c(), tuple.d()); + }; + Mockito.reset(entity1); + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + createMatch.apply(entity1, value1, entityGroup, valueGroup), + createMatch.apply(entity1, value2, entityGroup, valueGroup), + createMatch.apply(entity2, value1, entityGroup, valueGroup), + createMatch.apply(entity2, value2, entityGroup, valueGroup)); + Mockito.verify(entity1, 
Mockito.atLeastOnce()).getEntityGroup(); + + // Incrementally update a variable + Mockito.reset(entity1); + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(solution.getFirstValue()); + scoreDirector.afterVariableChanged(entity1, "value"); + assertScore(scoreDirector, + createMatch.apply(entity2, value1, entityGroup, valueGroup), + createMatch.apply(entity2, value2, entityGroup, valueGroup)); + Mockito.verify(entity1, Mockito.never()).getEntityGroup(); + + // Incrementally update a fact + scoreDirector.beforeProblemPropertyChanged(entity3); + entity3.setEntityGroup(entityGroup); + scoreDirector.afterProblemPropertyChanged(entity3); + assertScore(scoreDirector, + createMatch.apply(entity2, value1, entityGroup, valueGroup), + createMatch.apply(entity2, value2, entityGroup, valueGroup), + createMatch.apply(entity3, value1, entityGroup, valueGroup), + createMatch.apply(entity3, value2, entityGroup, valueGroup)); + + // Remove entity + scoreDirector.beforeEntityRemoved(entity3); + solution.getEntityList().remove(entity3); + scoreDirector.afterEntityRemoved(entity3); + assertScore(scoreDirector, + createMatch.apply(entity2, value1, entityGroup, valueGroup), + createMatch.apply(entity2, value2, entityGroup, valueGroup)); + + // Add it back again, to make sure it was properly removed before + scoreDirector.beforeEntityAdded(entity3); + solution.getEntityList().add(entity3); + scoreDirector.afterEntityAdded(entity3); + assertScore(scoreDirector, + createMatch.apply(entity2, value1, entityGroup, valueGroup), + createMatch.apply(entity2, value2, entityGroup, valueGroup), + createMatch.apply(entity3, value1, entityGroup, valueGroup), + createMatch.apply(entity3, value2, entityGroup, valueGroup)); + } + + @Override + @TestTemplate + public void filter_0_changed() { + assertPrecomputeFilterChanged( + (precomputeFactory, entityGroup, valueGroup) -> precomputeFactory.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + 
.join(TestdataLavishEntityGroup.class) + .join(TestdataLavishValueGroup.class) + .filter((entity, value, matchedEntityGroup, matchedValueGroup) -> entity.getEntityGroup() == entityGroup + && value.getValueGroup() == valueGroup + && matchedEntityGroup == entityGroup + && matchedValueGroup == valueGroup), + (entity, value, entityGroup, valueGroup) -> entity, + Quadruple::new); + } + + @Override + @TestTemplate + public void filter_1_changed() { + assertPrecomputeFilterChanged( + (precomputeFactory, entityGroup, valueGroup) -> precomputeFactory.forEachUnfiltered(TestdataLavishValue.class) + .join(TestdataLavishEntity.class) + .join(TestdataLavishEntityGroup.class) + .join(TestdataLavishValueGroup.class) + .filter((value, entity, matchedEntityGroup, matchedValueGroup) -> entity.getEntityGroup() == entityGroup + && value.getValueGroup() == valueGroup + && matchedEntityGroup == entityGroup + && matchedValueGroup == valueGroup), + (value, entity, entityGroup, valueGroup) -> entity, + (entity, value, entityGroup, valueGroup) -> new Quadruple<>(value, entity, entityGroup, valueGroup)); + } + + @Override + @TestTemplate + public void filter_2_changed() { + assertPrecomputeFilterChanged( + (precomputeFactory, entityGroup, valueGroup) -> precomputeFactory.forEachUnfiltered(TestdataLavishValue.class) + .join(TestdataLavishEntityGroup.class) + .join(TestdataLavishEntity.class) + .join(TestdataLavishValueGroup.class) + .filter((value, matchedEntityGroup, entity, matchedValueGroup) -> entity.getEntityGroup() == entityGroup + && value.getValueGroup() == valueGroup + && matchedEntityGroup == entityGroup + && matchedValueGroup == valueGroup), + (value, entityGroup, entity, valueGroup) -> entity, + (entity, value, entityGroup, valueGroup) -> new Quadruple<>(value, entityGroup, entity, valueGroup)); + } + + @Override + @TestTemplate + public void filter_3_changed() { + assertPrecomputeFilterChanged( + (precomputeFactory, entityGroup, valueGroup) -> 
precomputeFactory.forEachUnfiltered(TestdataLavishValue.class) + .join(TestdataLavishEntityGroup.class) + .join(TestdataLavishValueGroup.class) + .join(TestdataLavishEntity.class) + .filter((value, matchedEntityGroup, matchedValueGroup, entity) -> entity.getEntityGroup() == entityGroup + && value.getValueGroup() == valueGroup + && matchedEntityGroup == entityGroup + && matchedValueGroup == valueGroup), + (value, entityGroup, valueGroup, entity) -> entity, + (entity, value, entityGroup, valueGroup) -> new Quadruple<>(value, entityGroup, valueGroup, entity)); + } + + private void assertPrecompute(TestdataLavishSolution solution, + List> expectedValues, + Function> entityStreamSupplier) { + var scoreDirector = + buildScoreDirector(factory -> factory.precompute(entityStreamSupplier) + .ifExists(TestdataLavishEntity.class) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector); + + for (var entity : solution.getEntityList()) { + scoreDirector.beforeVariableChanged(entity, "value"); + entity.setValue(solution.getFirstValue()); + scoreDirector.afterVariableChanged(entity, "value"); + } + + assertScore(scoreDirector, expectedValues.stream() + .map(quad -> new Object[] { quad.a(), quad.b(), quad.c(), quad.d() }) + .map(AbstractConstraintStreamTest::assertMatch) + .toArray(AssertableMatch[]::new)); + } + + @Override + @TestTemplate + public void ifExists() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, 
List.of(new Quadruple<>(entityWithGroup, value, value, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .ifExists(TestdataLavishEntityGroup.class, Joiners.equal( + (a, b, c, d) -> a.getEntityGroup(), Function.identity()))); + } + + @Override + @TestTemplate + public void ifNotExists() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Quadruple<>(entityWithoutGroup, value, value, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .ifNotExists(TestdataLavishEntityGroup.class, Joiners.equal( + (a, b, c, d) -> a.getEntityGroup(), Function.identity()))); + } + + @Override + @TestTemplate + public void groupBy() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Quadruple<>(entityGroup, 1, 1, 1)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .filter(entity -> 
entity.getEntityGroup() != null) + .groupBy(TestdataLavishEntity::getEntityGroup, + ConstraintCollectors.count(), + ConstraintCollectors.count(), + ConstraintCollectors.count())); + } + + @Override + @TestTemplate + public void flattenLast() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Quadruple<>(entityWithoutGroup, value, value, value), + new Quadruple<>(entityWithGroup, value, value, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .groupBy(ConstraintCollectors.toList()) + .flattenLast(entityList -> entityList) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class)); + } + + @Override + @TestTemplate + public void map() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup1 = new TestdataLavishEntity(); + var entityWithGroup2 = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup1.setEntityGroup(entityGroup); + entityWithGroup2.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Quadruple<>(entityGroup, value, value, value), + new Quadruple<>(entityGroup, value, value, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + 
.join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .filter((entity, joinedValue1, joinedValue2, joinedValue3) -> entity.getEntityGroup() != null) + .map((entity, joinedValue1, joinedValue2, joinedValue3) -> entity.getEntityGroup(), + (entity, joinedValue1, joinedValue2, joinedValue3) -> joinedValue1, + (entity, joinedValue1, joinedValue2, joinedValue3) -> joinedValue2, + (entity, joinedValue1, joinedValue2, joinedValue3) -> joinedValue3)); + } + + @Override + @TestTemplate + public void concat() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, + List.of(new Quadruple<>(entityWithoutGroup, value, value, value), + new Quadruple<>(entityWithGroup, value, value, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .filter((entity, joinedValue1, joinedValue2, joinedValue3) -> entity.getEntityGroup() == null) + .concat(pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .filter((entity, joinedValue1, joinedValue2, + joinedValue3) -> entity.getEntityGroup() != null))); + } + + @Override + @TestTemplate + public void distinct() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup1 = new TestdataLavishEntity(); + var entityWithGroup2 = new 
TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup1.setEntityGroup(entityGroup); + entityWithGroup2.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Quadruple<>(entityGroup, value, value, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .filter((entity, joinedValue1, joinedValue2, joinedValue3) -> entity.getEntityGroup() != null) + .map((entity, joinedValue1, joinedValue2, joinedValue3) -> entity.getEntityGroup(), + (entity, joinedValue1, joinedValue2, joinedValue3) -> joinedValue1, + (entity, joinedValue1, joinedValue2, joinedValue3) -> joinedValue2, + (entity, joinedValue1, joinedValue2, joinedValue3) -> joinedValue3) + .distinct()); + } + + @Override + @TestTemplate + public void complement() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup1 = new TestdataLavishEntity(); + var entityWithGroup2 = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup1.setEntityGroup(entityGroup); + entityWithGroup2.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of( + new Quadruple<>(entityWithGroup1, value, value, value), + new Quadruple<>(entityWithGroup2, value, value, value), + new Quadruple<>(entityWithoutGroup, null, null, null)), + pf -> 
pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .filter((entity, joinedValue1, joinedValue2, joinedValue3) -> entity.getEntityGroup() != null) + .complement(TestdataLavishEntity.class)); + } +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/tri/AbstractTriConstraintStreamNodeSharingTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/tri/AbstractTriConstraintStreamNodeSharingTest.java index ee191d8b64..e29b24b4e9 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/tri/AbstractTriConstraintStreamNodeSharingTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/tri/AbstractTriConstraintStreamNodeSharingTest.java @@ -21,6 +21,7 @@ import ai.timefold.solver.core.testdomain.TestdataSolution; import ai.timefold.solver.core.testdomain.TestdataValue; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.TestTemplate; @@ -573,4 +574,38 @@ public void sameSourcesConcat() { .concat(baseStream.filter(filter1))) .isSameAs(baseStream.concat(baseStream.filter(filter1))); } + + @Override + @TestTemplate + public void sameDataPrecompute() { + TriPredicate filter1 = (a, b, c) -> true; + Assertions.assertThat((TriConstraintStream) constraintFactory.precompute( + precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .join(TestdataEntity.class) + .join(TestdataEntity.class) + .filter(filter1))) + .isSameAs(constraintFactory.precompute( + precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .join(TestdataEntity.class) + .join(TestdataEntity.class) + .filter(filter1))); + } + + @Override + @TestTemplate + public void differentDataPrecompute() { + TriPredicate filter1 = (a, b, c) -> true; + TriPredicate filter2 = (a, b, c) -> false; + + 
Assertions.assertThat((TriConstraintStream) constraintFactory.precompute( + precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .join(TestdataEntity.class) + .join(TestdataEntity.class) + .filter(filter1))) + .isNotSameAs(constraintFactory.precompute( + precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .join(TestdataEntity.class) + .join(TestdataEntity.class) + .filter(filter2))); + } } diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/tri/AbstractTriConstraintStreamPrecomputeTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/tri/AbstractTriConstraintStreamPrecomputeTest.java new file mode 100644 index 0000000000..a99559da99 --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/tri/AbstractTriConstraintStreamPrecomputeTest.java @@ -0,0 +1,380 @@ +package ai.timefold.solver.core.impl.score.stream.common.tri; + +import java.util.List; +import java.util.function.Function; + +import ai.timefold.solver.core.api.function.TriFunction; +import ai.timefold.solver.core.api.score.buildin.simple.SimpleScore; +import ai.timefold.solver.core.api.score.stream.ConstraintCollectors; +import ai.timefold.solver.core.api.score.stream.Joiners; +import ai.timefold.solver.core.api.score.stream.PrecomputeFactory; +import ai.timefold.solver.core.api.score.stream.tri.TriConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStreamTest; +import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamImplSupport; +import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamPrecomputeTest; +import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamTestExtension; +import ai.timefold.solver.core.impl.util.Triple; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishEntity; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishEntityGroup; 
+import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishSolution; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishValue; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishValueGroup; + +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.parallel.Execution; +import org.junit.jupiter.api.parallel.ExecutionMode; +import org.mockito.Mockito; + +@ExtendWith(ConstraintStreamTestExtension.class) +@Execution(ExecutionMode.CONCURRENT) +public abstract class AbstractTriConstraintStreamPrecomputeTest extends AbstractConstraintStreamTest + implements ConstraintStreamPrecomputeTest { + protected AbstractTriConstraintStreamPrecomputeTest(ConstraintStreamImplSupport implSupport) { + super(implSupport); + } + + private void assertPrecomputeFilterChanged( + TriFunction> precomputeStream, + TriFunction entityPicker, + TriFunction> inputDataToTuple) { + var solution = TestdataLavishSolution.generateSolution(); + var entityGroup = new TestdataLavishEntityGroup("MyEntityGroup"); + var valueGroup = new TestdataLavishValueGroup("MyValueGroup"); + solution.getEntityGroupList().add(entityGroup); + solution.getValueGroupList().add(valueGroup); + + var value1 = Mockito.spy(new TestdataLavishValue("MyValue 1", valueGroup)); + solution.getValueList().add(value1); + var value2 = Mockito.spy(new TestdataLavishValue("MyValue 2", valueGroup)); + solution.getValueList().add(value2); + var value3 = Mockito.spy(new TestdataLavishValue("MyValue 3", null)); + solution.getValueList().add(value3); + + var entity1 = Mockito.spy(new TestdataLavishEntity("MyEntity 1", entityGroup, value1)); + solution.getEntityList().add(entity1); + var entity2 = new TestdataLavishEntity("MyEntity 2", entityGroup, value1); + solution.getEntityList().add(entity2); + var entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value1); + 
solution.getEntityList().add(entity3); + + var scoreDirector = + buildScoreDirector(factory -> factory.precompute(pf -> precomputeStream.apply(pf, entityGroup, valueGroup)) + .filter((a, b, c) -> entityPicker.apply(a, b, c).getValue() == value1) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + var createMatch = + (TriFunction) (entity, + value, matchEntityGroup) -> { + var tuple = inputDataToTuple.apply(entity, value, matchEntityGroup); + return assertMatch(tuple.a(), tuple.b(), tuple.c()); + }; + Mockito.reset(entity1); + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + createMatch.apply(entity1, value1, entityGroup), + createMatch.apply(entity1, value2, entityGroup), + createMatch.apply(entity2, value1, entityGroup), + createMatch.apply(entity2, value2, entityGroup)); + Mockito.verify(entity1, Mockito.atLeastOnce()).getEntityGroup(); + + // Incrementally update a variable + Mockito.reset(entity1); + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(solution.getFirstValue()); + scoreDirector.afterVariableChanged(entity1, "value"); + assertScore(scoreDirector, + createMatch.apply(entity2, value1, entityGroup), + createMatch.apply(entity2, value2, entityGroup)); + Mockito.verify(entity1, Mockito.never()).getEntityGroup(); + + // Incrementally update a fact + scoreDirector.beforeProblemPropertyChanged(entity3); + entity3.setEntityGroup(entityGroup); + scoreDirector.afterProblemPropertyChanged(entity3); + assertScore(scoreDirector, + createMatch.apply(entity2, value1, entityGroup), + createMatch.apply(entity2, value2, entityGroup), + createMatch.apply(entity3, value1, entityGroup), + createMatch.apply(entity3, value2, entityGroup)); + + // Remove entity + scoreDirector.beforeEntityRemoved(entity3); + solution.getEntityList().remove(entity3); + scoreDirector.afterEntityRemoved(entity3); + assertScore(scoreDirector, + createMatch.apply(entity2, value1, entityGroup), + 
createMatch.apply(entity2, value2, entityGroup)); + + // Add it back again, to make sure it was properly removed before + scoreDirector.beforeEntityAdded(entity3); + solution.getEntityList().add(entity3); + scoreDirector.afterEntityAdded(entity3); + assertScore(scoreDirector, + createMatch.apply(entity2, value1, entityGroup), + createMatch.apply(entity2, value2, entityGroup), + createMatch.apply(entity3, value1, entityGroup), + createMatch.apply(entity3, value2, entityGroup)); + } + + @Override + @TestTemplate + public void filter_0_changed() { + assertPrecomputeFilterChanged( + (precomputeFactory, entityGroup, valueGroup) -> precomputeFactory.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishEntityGroup.class) + .filter((entity, value, matchedEntityGroup) -> entity.getEntityGroup() == entityGroup + && value.getValueGroup() == valueGroup + && matchedEntityGroup == entityGroup), + (entity, value, entityGroup) -> entity, + Triple::new); + } + + @Override + @TestTemplate + public void filter_1_changed() { + assertPrecomputeFilterChanged( + (precomputeFactory, entityGroup, valueGroup) -> precomputeFactory.forEachUnfiltered(TestdataLavishValue.class) + .join(TestdataLavishEntity.class) + .join(TestdataLavishEntityGroup.class) + .filter((value, entity, matchedEntityGroup) -> entity.getEntityGroup() == entityGroup + && value.getValueGroup() == valueGroup + && matchedEntityGroup == entityGroup), + (value, entity, entityGroup) -> entity, + (entity, value, entityGroup) -> new Triple<>(value, entity, entityGroup)); + } + + @Override + @TestTemplate + public void filter_2_changed() { + assertPrecomputeFilterChanged( + (precomputeFactory, entityGroup, valueGroup) -> precomputeFactory.forEachUnfiltered(TestdataLavishValue.class) + .join(TestdataLavishEntityGroup.class) + .join(TestdataLavishEntity.class) + .filter((value, matchedEntityGroup, entity) -> entity.getEntityGroup() == entityGroup + && value.getValueGroup() == 
valueGroup + && matchedEntityGroup == entityGroup), + (value, entityGroup, entity) -> entity, + (entity, value, entityGroup) -> new Triple<>(value, entityGroup, entity)); + } + + private void assertPrecompute(TestdataLavishSolution solution, + List> expectedValues, + Function> entityStreamSupplier) { + var scoreDirector = + buildScoreDirector(factory -> factory.precompute(entityStreamSupplier) + .ifExists(TestdataLavishEntity.class) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector); + + for (var entity : solution.getEntityList()) { + scoreDirector.beforeVariableChanged(entity, "value"); + entity.setValue(solution.getFirstValue()); + scoreDirector.afterVariableChanged(entity, "value"); + } + + assertScore(scoreDirector, expectedValues.stream() + .map(triple -> new Object[] { triple.a(), triple.b(), triple.c() }) + .map(AbstractConstraintStreamTest::assertMatch) + .toArray(AssertableMatch[]::new)); + } + + @Override + @TestTemplate + public void ifExists() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Triple<>(entityWithGroup, value, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .ifExists(TestdataLavishEntityGroup.class, Joiners.equal( + (a, b, c) -> a.getEntityGroup(), Function.identity()))); + } + + @Override + @TestTemplate + public void ifNotExists() { + var solution = 
TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Triple<>(entityWithoutGroup, value, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .ifNotExists(TestdataLavishEntityGroup.class, Joiners.equal( + (a, b, c) -> a.getEntityGroup(), Function.identity()))); + } + + @Override + @TestTemplate + public void groupBy() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Triple<>(entityGroup, 1, 1)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .filter(entity -> entity.getEntityGroup() != null) + .groupBy(TestdataLavishEntity::getEntityGroup, + ConstraintCollectors.count(), + ConstraintCollectors.count())); + } + + @Override + @TestTemplate + public void flattenLast() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + 
solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Triple<>(entityWithoutGroup, value, value), + new Triple<>(entityWithGroup, value, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .groupBy(ConstraintCollectors.toList()) + .flattenLast(entityList -> entityList) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class)); + } + + @Override + @TestTemplate + public void map() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup1 = new TestdataLavishEntity(); + var entityWithGroup2 = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup1.setEntityGroup(entityGroup); + entityWithGroup2.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Triple<>(entityGroup, value, value), + new Triple<>(entityGroup, value, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .filter((entity, joinedValue1, joinedValue2) -> entity.getEntityGroup() != null) + .map((entity, joinedValue1, joinedValue2) -> entity.getEntityGroup(), + (entity, joinedValue1, joinedValue2) -> joinedValue1, + (entity, joinedValue1, joinedValue2) -> joinedValue2)); + } + + @Override + @TestTemplate + public void concat() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new 
TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, + List.of(new Triple<>(entityWithoutGroup, value, value), new Triple<>(entityWithGroup, value, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .filter((entity, joinedValue1, joinedValue2) -> entity.getEntityGroup() == null) + .concat(pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .filter((entity, joinedValue1, joinedValue2) -> entity.getEntityGroup() != null))); + } + + @Override + @TestTemplate + public void distinct() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup1 = new TestdataLavishEntity(); + var entityWithGroup2 = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup1.setEntityGroup(entityGroup); + entityWithGroup2.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of(new Triple<>(entityGroup, value, value)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .filter((entity, joinedValue1, joinedValue2) -> entity.getEntityGroup() != null) + .map((entity, joinedValue1, joinedValue2) -> entity.getEntityGroup(), + (entity, joinedValue1, joinedValue2) -> joinedValue1, + (entity, joinedValue1, joinedValue2) -> joinedValue2) + 
.distinct()); + } + + @Override + @TestTemplate + public void complement() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup1 = new TestdataLavishEntity(); + var entityWithGroup2 = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup1.setEntityGroup(entityGroup); + entityWithGroup2.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2)); + solution.getEntityGroupList().add(entityGroup); + var value = new TestdataLavishValue(); + solution.getValueList().add(value); + + assertPrecompute(solution, List.of( + new Triple<>(entityWithGroup1, value, value), + new Triple<>(entityWithGroup2, value, value), + new Triple<>(entityWithoutGroup, null, null)), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .join(TestdataLavishValue.class) + .join(TestdataLavishValue.class) + .filter((entity, joinedValue1, joinedValue2) -> entity.getEntityGroup() != null) + .complement(TestdataLavishEntity.class) + .distinct()); + } +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/uni/AbstractUniConstraintStreamNodeSharingTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/uni/AbstractUniConstraintStreamNodeSharingTest.java index e8309999e5..4887b36d39 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/uni/AbstractUniConstraintStreamNodeSharingTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/uni/AbstractUniConstraintStreamNodeSharingTest.java @@ -4,6 +4,7 @@ import java.util.Collections; import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.function.Predicate; import java.util.function.ToIntFunction; @@ -20,6 +21,7 @@ import 
ai.timefold.solver.core.testdomain.TestdataSolution; import ai.timefold.solver.core.testdomain.TestdataValue; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.TestTemplate; @@ -40,6 +42,27 @@ public void setup() { baseStream = constraintFactory.forEach(TestdataEntity.class); } + // ************************************************************************ + // ForEach + // ************************************************************************ + @TestTemplate + public void sameRetrievalSemanticsForEach() { + assertThat(constraintFactory.forEachUnfiltered(TestdataEntity.class)) + .isSameAs(constraintFactory.forEachUnfiltered(TestdataEntity.class)); + } + + @TestTemplate + public void differentRetrievalSemanticsForEach() { + var precomputeStream = new AtomicReference>(); + constraintFactory.precompute(pf -> { + var out = pf.forEachUnfiltered(TestdataEntity.class); + precomputeStream.set(out); + return out; + }); + assertThat(constraintFactory.forEachUnfiltered(TestdataEntity.class)) + .isNotSameAs(precomputeStream.get()); + } + + // ************************************************************************ + // Filter + // ************************************************************************ @@ -695,4 +718,30 @@ public void sameSourcesConcat() { .concat(baseStream.filter(filter1))) .isSameAs(baseStream.concat(baseStream.filter(filter1))); } + + @Override + @TestTemplate + public void sameDataPrecompute() { + Predicate filter1 = a -> true; + Assertions.assertThat((UniConstraintStream) constraintFactory.precompute( + precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .filter(filter1))) + .isSameAs(constraintFactory.precompute( + precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .filter(filter1))); + } + + @Override + @TestTemplate + public void differentDataPrecompute() { + Predicate filter1 = a -> true; + Predicate filter2 = a -> false; + + 
Assertions.assertThat((UniConstraintStream) constraintFactory.precompute( + precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .filter(filter1))) + .isNotSameAs(constraintFactory.precompute( + precomputeFactory -> precomputeFactory.forEachUnfiltered(TestdataEntity.class) + .filter(filter2))); + } } diff --git a/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/uni/AbstractUniConstraintStreamPrecomputeTest.java b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/uni/AbstractUniConstraintStreamPrecomputeTest.java new file mode 100644 index 0000000000..fc12a68996 --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/score/stream/common/uni/AbstractUniConstraintStreamPrecomputeTest.java @@ -0,0 +1,274 @@ +package ai.timefold.solver.core.impl.score.stream.common.uni; + +import java.util.List; +import java.util.function.Function; + +import ai.timefold.solver.core.api.score.buildin.simple.SimpleScore; +import ai.timefold.solver.core.api.score.stream.ConstraintCollectors; +import ai.timefold.solver.core.api.score.stream.Joiners; +import ai.timefold.solver.core.api.score.stream.PrecomputeFactory; +import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream; +import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStreamTest; +import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamImplSupport; +import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamPrecomputeTest; +import ai.timefold.solver.core.impl.score.stream.common.ConstraintStreamTestExtension; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishEntity; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishEntityGroup; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishSolution; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishValue; + +import org.junit.jupiter.api.TestTemplate; +import 
org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.parallel.Execution; +import org.junit.jupiter.api.parallel.ExecutionMode; +import org.mockito.Mockito; + +@ExtendWith(ConstraintStreamTestExtension.class) +@Execution(ExecutionMode.CONCURRENT) +public abstract class AbstractUniConstraintStreamPrecomputeTest extends AbstractConstraintStreamTest + implements ConstraintStreamPrecomputeTest { + protected AbstractUniConstraintStreamPrecomputeTest(ConstraintStreamImplSupport implSupport) { + super(implSupport); + } + + @Override + @TestTemplate + public void filter_0_changed() { + var solution = TestdataLavishSolution.generateSolution(); + var entityGroup = new TestdataLavishEntityGroup("MyEntityGroup"); + solution.getEntityGroupList().add(entityGroup); + var entity1 = Mockito.spy(new TestdataLavishEntity("MyEntity 1", entityGroup, solution.getFirstValue())); + solution.getEntityList().add(entity1); + var entity2 = new TestdataLavishEntity("MyEntity 2", entityGroup, solution.getFirstValue()); + solution.getEntityList().add(entity2); + var entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + solution.getFirstValue()); + solution.getEntityList().add(entity3); + + var scoreDirector = + buildScoreDirector(factory -> factory.precompute(data -> data.forEachUnfiltered(TestdataLavishEntity.class) + .filter(entity -> entity.getEntityGroup() == entityGroup)) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + Mockito.reset(entity1); + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2)); + Mockito.verify(entity1, Mockito.atLeastOnce()).getEntityGroup(); + + // Incrementally update a variable + Mockito.reset(entity1); + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(new TestdataLavishValue()); + scoreDirector.afterVariableChanged(entity1, "value"); + assertScore(scoreDirector, + 
assertMatch(entity1), + assertMatch(entity2)); + Mockito.verify(entity1, Mockito.never()).getEntityGroup(); + + // Incrementally update a fact + scoreDirector.beforeProblemPropertyChanged(entity3); + entity3.setEntityGroup(entityGroup); + scoreDirector.afterProblemPropertyChanged(entity3); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2), + assertMatch(entity3)); + + // Remove entity + scoreDirector.beforeEntityRemoved(entity3); + solution.getEntityList().remove(entity3); + scoreDirector.afterEntityRemoved(entity3); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2)); + + // Add it back again, to make sure it was properly removed before + scoreDirector.beforeEntityAdded(entity3); + solution.getEntityList().add(entity3); + scoreDirector.afterEntityAdded(entity3); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2), + assertMatch(entity3)); + } + + private void assertPrecompute(TestdataLavishSolution solution, + List expectedValues, + Function> entityStreamSupplier) { + var scoreDirector = + buildScoreDirector(factory -> factory.precompute(entityStreamSupplier) + .ifExists(TestdataLavishEntity.class) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector); + + for (var entity : solution.getEntityList()) { + scoreDirector.beforeVariableChanged(entity, "value"); + entity.setValue(solution.getFirstValue()); + scoreDirector.afterVariableChanged(entity, "value"); + } + + assertScore(scoreDirector, expectedValues.stream() + .map(AbstractConstraintStreamTest::assertMatch) + .toArray(AssertableMatch[]::new)); + } + + @Override + @TestTemplate + public void ifExists() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + 
entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + solution.getValueList().add(new TestdataLavishValue()); + + assertPrecompute(solution, List.of(entityWithGroup), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .ifExists(TestdataLavishEntityGroup.class, Joiners.equal( + TestdataLavishEntity::getEntityGroup, Function.identity()))); + } + + @Override + @TestTemplate + public void ifNotExists() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + solution.getValueList().add(new TestdataLavishValue()); + + assertPrecompute(solution, List.of(entityWithoutGroup), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .ifNotExists(TestdataLavishEntityGroup.class, Joiners.equal( + TestdataLavishEntity::getEntityGroup, Function.identity()))); + } + + @Override + @TestTemplate + public void groupBy() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + solution.getValueList().add(new TestdataLavishValue()); + + assertPrecompute(solution, List.of(entityGroup), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .filter(entity -> entity.getEntityGroup() != null) + .groupBy(TestdataLavishEntity::getEntityGroup)); + } + + @Override + 
@TestTemplate + public void flattenLast() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + solution.getEntityGroupList().add(entityGroup); + solution.getValueList().add(new TestdataLavishValue()); + + assertPrecompute(solution, List.of(entityWithoutGroup, entityWithGroup), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .groupBy(ConstraintCollectors.toList()) + .flattenLast(entityList -> entityList)); + } + + @Override + @TestTemplate + public void map() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup1 = new TestdataLavishEntity(); + var entityWithGroup2 = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup1.setEntityGroup(entityGroup); + entityWithGroup2.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2)); + solution.getEntityGroupList().add(entityGroup); + solution.getValueList().add(new TestdataLavishValue()); + + assertPrecompute(solution, List.of(entityGroup, entityGroup), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .filter(entity -> entity.getEntityGroup() != null) + .map(TestdataLavishEntity::getEntityGroup)); + } + + @Override + @TestTemplate + public void concat() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup)); + 
solution.getEntityGroupList().add(entityGroup); + solution.getValueList().add(new TestdataLavishValue()); + + assertPrecompute(solution, List.of(entityWithoutGroup, entityWithGroup), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .filter(entity -> entity.getEntityGroup() == null) + .concat(pf.forEachUnfiltered(TestdataLavishEntity.class) + .filter(entity -> entity.getEntityGroup() != null))); + } + + @Override + @TestTemplate + public void distinct() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup1 = new TestdataLavishEntity(); + var entityWithGroup2 = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup1.setEntityGroup(entityGroup); + entityWithGroup2.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2)); + solution.getEntityGroupList().add(entityGroup); + solution.getValueList().add(new TestdataLavishValue()); + + assertPrecompute(solution, List.of(entityGroup), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .filter(entity -> entity.getEntityGroup() != null) + .map(TestdataLavishEntity::getEntityGroup) + .distinct()); + } + + @Override + @TestTemplate + public void complement() { + var solution = TestdataLavishSolution.generateEmptySolution(); + var entityWithoutGroup = new TestdataLavishEntity(); + var entityWithGroup1 = new TestdataLavishEntity(); + var entityWithGroup2 = new TestdataLavishEntity(); + var entityGroup = new TestdataLavishEntityGroup(); + entityWithGroup1.setEntityGroup(entityGroup); + entityWithGroup2.setEntityGroup(entityGroup); + solution.getEntityList().addAll(List.of(entityWithoutGroup, entityWithGroup1, entityWithGroup2)); + solution.getEntityGroupList().add(entityGroup); + solution.getValueList().add(new TestdataLavishValue()); + + assertPrecompute(solution, List.of(entityWithGroup1, entityWithGroup2, 
entityWithoutGroup), + pf -> pf.forEachUnfiltered(TestdataLavishEntity.class) + .filter(entity -> entity.getEntityGroup() != null) + .complement(TestdataLavishEntity.class)); + } +} diff --git a/core/src/test/java/ai/timefold/solver/core/testconstraint/TestConstraintFactory.java b/core/src/test/java/ai/timefold/solver/core/testconstraint/TestConstraintFactory.java index 8e269e8122..434fd27a2f 100644 --- a/core/src/test/java/ai/timefold/solver/core/testconstraint/TestConstraintFactory.java +++ b/core/src/test/java/ai/timefold/solver/core/testconstraint/TestConstraintFactory.java @@ -1,8 +1,11 @@ package ai.timefold.solver.core.testconstraint; import java.util.Objects; +import java.util.function.Function; import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.api.score.stream.ConstraintStream; +import ai.timefold.solver.core.api.score.stream.PrecomputeFactory; import ai.timefold.solver.core.api.score.stream.uni.UniConstraintStream; import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor; import ai.timefold.solver.core.impl.score.stream.common.InnerConstraintFactory; @@ -43,6 +46,12 @@ public SolutionDescriptor getSolutionDescriptor() { throw new UnsupportedOperationException(); } + @Override + public @NonNull Stream_ + precompute(@NonNull Function<@NonNull PrecomputeFactory, @NonNull Stream_> stream) { + throw new UnsupportedOperationException(); + } + @Override public @NonNull UniConstraintStream from(@NonNull Class fromClass) { throw new UnsupportedOperationException(); diff --git a/core/src/test/java/ai/timefold/solver/core/testdomain/score/lavish/TestdataLavishSolution.java b/core/src/test/java/ai/timefold/solver/core/testdomain/score/lavish/TestdataLavishSolution.java index e688ec9a56..b97898955f 100644 --- a/core/src/test/java/ai/timefold/solver/core/testdomain/score/lavish/TestdataLavishSolution.java +++ b/core/src/test/java/ai/timefold/solver/core/testdomain/score/lavish/TestdataLavishSolution.java @@ -19,6 
+19,10 @@ public static SolutionDescriptor buildSolutionDescriptor return SolutionDescriptor.buildSolutionDescriptor(TestdataLavishSolution.class, TestdataLavishEntity.class); } + public static TestdataLavishSolution generateEmptySolution() { + return generateSolution(0, 0, 0, 0); + } + public static TestdataLavishSolution generateSolution() { return generateSolution(2, 5, 3, 7); } diff --git a/spring-integration/spring-boot-autoconfigure/src/test/java/ai/timefold/solver/spring/boot/autoconfigure/TimefoldSolverMultipleSolverAutoConfigurationTest.java b/spring-integration/spring-boot-autoconfigure/src/test/java/ai/timefold/solver/spring/boot/autoconfigure/TimefoldSolverMultipleSolverAutoConfigurationTest.java index 75989e3916..5195b222c8 100644 --- a/spring-integration/spring-boot-autoconfigure/src/test/java/ai/timefold/solver/spring/boot/autoconfigure/TimefoldSolverMultipleSolverAutoConfigurationTest.java +++ b/spring-integration/spring-boot-autoconfigure/src/test/java/ai/timefold/solver/spring/boot/autoconfigure/TimefoldSolverMultipleSolverAutoConfigurationTest.java @@ -320,10 +320,7 @@ void solverWithYaml() { assertThat(solution.getScore().score()).isNotNegative(); } }); - } - @Test - void invalidYaml() { assertThatCode(() -> contextRunner .withInitializer(new ConfigDataApplicationContextInitializer()) .withSystemProperties(