Skip to content
38 changes: 26 additions & 12 deletions cpmpy/expressions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,28 @@ def value(self):
elif self.name == ">=": return arg_vals[0] >= arg_vals[1]
return None # default

def get_bounds(self):
(lb1, ub1), (lb2, ub2) = get_bounds(self.args[0]), get_bounds(self.args[1])
if self.name == "==":
if lb1 == ub1 == lb2 == ub2: return (1,1) # equal domains, trivially true
if ub1 < lb2 or ub2 < lb1: return (0,0) # disjoint, trivially false
if self.name == "!=":
if ub1 < lb2 or ub2 < lb1: return (1,1) # disjoint, trivially true
if lb1 == ub1 == lb2 == ub2: return (0,0) # equal domains, trivially false
if self.name == "<=":
if ub1 <= lb2: return (1,1) # domain of lhs is leq domain of rhs
if lb1 > ub2: return (0,0) # domain of lhs is gt domain of rhs
if self.name == "<":
if ub1 < lb2: return (1,1) # domain of lhs is lt domain of rhs
if lb1 >= ub2: return (0,0) # domain of lhs is geq domain of rhs
if self.name == ">=":
if lb1 >= ub2: return (1,1) # domain of lhs is geq domain of rhs
if ub1 < lb2: return (0,0) # domain of lhs is lt domain of rhs
if self.name == ">":
if lb1 > ub2: return (1,1) # domain of lhs is gt domain of rhs
if ub1 <= lb2: return (0,0) # domain of lhs is leq domain of rhs
return (0,1)


class Operator(Expression):
"""
Expand Down Expand Up @@ -751,18 +773,10 @@ def get_bounds(self):
lowerbound, upperbound = sum(lbs), sum(ubs)
elif self.name == 'wsum':
weights, vars = self.args
bounds = []
lowerbound, upperbound = 0,0
#this may seem like too many lines, but avoiding np.sum avoids overflowing things at int32 bounds
for w, (lb, ub) in zip(weights, [get_bounds(arg) for arg in vars]):
x,y = int(w) * lb, int(w) * ub
if x <= y: # x is the lb of this arg
lowerbound += x
upperbound += y
else:
lowerbound += y
upperbound += x

lbs, ubs = get_bounds(vars)
lbs, ubs = [w * lb for w,lb in zip(weights,lbs)], [w * ub for w, ub in zip(weights,ubs)]
lowerbound = sum(lb if lb <= ub else ub for lb,ub in zip(lbs,ubs))
upperbound = sum(ub if ub >= lb else lb for lb, ub in zip(lbs, ubs))
elif self.name == 'sub':
lb1, ub1 = get_bounds(self.args[0])
lb2, ub2 = get_bounds(self.args[1])
Expand Down
2 changes: 1 addition & 1 deletion cpmpy/expressions/globalconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def decompose(self):
decomp = [sum(self.args[:2]) == 1]
if len(self.args) > 2:
decomp = Xor([decomp,self.args[2:]]).decompose()[0]
return decomp, []
return cp.transformations.normalize.simplify_boolean(decomp), []

def value(self):
return sum(argvals(self.args)) % 2 == 1
Expand Down
11 changes: 10 additions & 1 deletion cpmpy/transformations/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ..expressions.core import BoolVal, Expression, Comparison, Operator
from ..expressions.globalfunctions import GlobalFunction
from ..expressions.utils import eval_comparison, is_false_cst, is_true_cst, is_boolexpr, is_num, is_bool
from ..expressions.utils import eval_comparison, is_false_cst, is_true_cst, is_boolexpr, is_num, is_bool, get_bounds
from ..expressions.variables import NDVarArray, _BoolVarImpl
from ..exceptions import NotSupportedError
from ..expressions.globalconstraints import GlobalConstraint
Expand Down Expand Up @@ -169,6 +169,15 @@ def simplify_boolean(lst_of_expr, num_context=False):
elif isinstance(expr, Comparison):
lhs, rhs = simplify_boolean(expr.args, num_context=True)
name = expr.name

lb, ub = get_bounds(eval_comparison(name, lhs, rhs))
if lb == 0 == ub:
newlist.append(0 if num_context else BoolVal(False))
continue
if lb == 1 == ub:
newlist.append(1 if num_context else BoolVal(True))
continue

if is_num(lhs) and is_boolexpr(rhs): # flip arguments of comparison to reduct nb of cases
if name == "<": name = ">"
elif name == ">": name = "<"
Expand Down
1 change: 1 addition & 0 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def global_constraints(solver):
if name == "Xor":
yield Xor(BOOL_ARGS)
yield Xor(BOOL_ARGS + [True,False])
yield Xor([True, BOOL_ARGS[0]])
continue
elif name == "Inverse":
expr = cls(NUM_ARGS, [1,0,2])
Expand Down
26 changes: 25 additions & 1 deletion tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from cpmpy.expressions import *
from cpmpy.expressions.variables import NDVarArray
from cpmpy.expressions.core import Comparison, Operator, Expression
from cpmpy.expressions.utils import eval_comparison, get_bounds, argval
from cpmpy.expressions.utils import eval_comparison, get_bounds, argval, all_pairs


class TestComparison(unittest.TestCase):
def test_comps(self):
Expand Down Expand Up @@ -450,6 +451,29 @@ def test_bounds_unary(self):
self.assertGreaterEqual(val,lb)
self.assertLessEqual(val,ub)

def test_bounds_comparison(self):

x_00 = intvar(0,0, name="x00")
x_01 = intvar(0,1, name="x01")
x_12= intvar(1,2, name="x12")
x_23 = intvar(2,3, name="x23")

for x,y in all_pairs([0, x_00, x_01, x_12, x_23]):
for comp in ['==','!=','<=','<','>=','>']:
x_bounds = get_bounds(x)
y_bounds = get_bounds(y)

total_vals = len(range(x_bounds[0],x_bounds[1]+1)) * len(range(y_bounds[0],y_bounds[1]+1))

for expr in [Comparison(comp, x,y), Comparison(comp, y,x)]:
lb, ub = expr.get_bounds()

if lb == 0 == ub:
self.assertEqual(cp.Model(expr).solveAll(), 0)
elif lb == 1 == ub:
self.assertEqual(cp.Model(expr).solveAll(), total_vals)
else:
self.assertNotEqual(cp.Model(expr).solveAll(), total_vals)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use assertLess()?
What with assertMore(..., 0)?


def test_incomplete_func(self):
# element constraint
Expand Down
20 changes: 10 additions & 10 deletions tests/test_flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,18 +167,18 @@ def test_constraint(self):
self.assertEqual( str(flatten_constraint( x&y&~z )), "[BV0, BV1, ~BV2]" )
self.assertEqual( str(flatten_constraint( x.implies(y) )), "[(BV0) -> (BV1)]" )
self.assertEqual( str(flatten_constraint( x|(y.implies(z)) )), "[or([BV0, ~BV1, BV2])]" )
self.assertEqual( str(flatten_constraint( (a > 10)&x )), "[IV0 > 10, BV0]" )
self.assertEqual( str(flatten_constraint( (a > 8)&x )), "[IV0 > 8, BV0]" )
cp.boolvar() # increase counter
self.assertEqual( str(flatten_constraint( (a > 10).implies(x) )), "[(IV0 > 10) -> (BV0)]" )
self.assertEqual( str(flatten_constraint( (a > 8).implies(x) )), "[(IV0 > 8) -> (BV0)]" )
cp.boolvar() # increase counter
self.assertEqual( str(flatten_constraint( (a > 10) )), "[IV0 > 10]" )
self.assertEqual( str(flatten_constraint( (a > 10) == 1 )), "[IV0 > 10]" )
self.assertEqual( str(flatten_constraint( (a > 10) == 0 )), "[IV0 <= 10]" )
self.assertEqual( str(flatten_constraint( (a > 10) == x )), "[(IV0 > 10) == (BV0)]" )
self.assertEqual( str(flatten_constraint( (a > 8) )), "[IV0 > 8]" )
self.assertEqual( str(flatten_constraint( (a > 8) == 1 )), "[IV0 > 8]" )
self.assertEqual( str(flatten_constraint( (a > 8) == 0 )), "[IV0 <= 8]" )
self.assertEqual( str(flatten_constraint( (a > 8) == x )), "[(IV0 > 8) == (BV0)]" )
#self.assertEqual( str(flatten_constraint( x == (a > 10) )), "[(IV0 > 10) == (BV0)]" ) # TODO, make it do the swap (again)
self.assertEqual( str(flatten_constraint( (a > 10) | (b + c > 2) )), "[(BV5) or (BV6), (IV0 > 10) == (BV5), ((IV1) + (IV2) > 2) == (BV6)]" )
self.assertEqual( str(flatten_constraint( a > 10 )), "[IV0 > 10]" )
self.assertEqual( str(flatten_constraint( 10 > a )), "[IV0 < 10]" ) # surprising
self.assertEqual( str(flatten_constraint( (a > 8) | (b + c > 2) )), "[(BV5) or (BV6), (IV0 > 8) == (BV5), ((IV1) + (IV2) > 2) == (BV6)]" )
self.assertEqual( str(flatten_constraint( a > 8 )), "[IV0 > 8]" )
self.assertEqual( str(flatten_constraint( 8 > a )), "[IV0 < 8]" ) # surprising
self.assertEqual( str(flatten_constraint( a+b > c )), "[((IV0) + (IV1)) > (IV2)]" )
#self.assertEqual( str(flatten_constraint( c < a+b )), "[((IV0) + (IV1)) > (IV2)]" ) # TODO, make it do the swap (again)
self.assertEqual( str(flatten_constraint( (a+b > c) == x|y )), "[(((IV0) + (IV1)) > (IV2)) == (BV7), ((BV0) or (BV1)) == (BV7)]" )
Expand Down Expand Up @@ -213,7 +213,7 @@ def test_constraint(self):
self.assertEqual( str(a % 1 == 0), "(IV0) mod 1 == 0" )

# boolexpr as numexpr
self.assertEqual( str(flatten_constraint((a + b == 2) <= c)), "[(BV11) <= (IV2), ((IV0) + (IV1) == 2) == (BV11)]" )
self.assertEqual( str(flatten_constraint((a + b == 2) < c)), "[(BV11) < (IV2), ((IV0) + (IV1) == 2) == (BV11)]" )

# != in boolexpr, bug #170
self.assertEqual( str(normalized_boolexpr(x != (a == 1))), "((BV12) == (~BV0), [(IV0 == 1) == (BV12)])" )
Expand Down
17 changes: 17 additions & 0 deletions tests/test_globalconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,23 @@ def test_xor_with_constants(self):
self.assertFalse(cp.Model(cp.Xor([False, False])).solve())
self.assertFalse(cp.Model(cp.Xor([False, False, False])).solve())

def test_issue_620(self):
a = cp.boolvar()
b = cp.boolvar()
c = cp.boolvar()

model = cp.Model(cp.Xor([(cp.Xor([a, b, c])) <= True, ~((cp.Xor([a, b, c])) <= True)]))

self.assertTrue(model.solve(solver='ortools'))
if "minizinc" in cp.SolverLookup.supported():
self.assertTrue(model.solve(solver='minizinc'))
if "z3" in cp.SolverLookup.supported():
self.assertTrue(model.solve(solver='z3'))
if "choco" in cp.SolverLookup.supported():
self.assertTrue(model.solve(solver='choco'))
if "gurobi" in cp.SolverLookup.supported():
self.assertTrue(model.solve(solver='gurobi'))

def test_ite_with_constants(self):
x,y,z = cp.boolvar(shape=3)
expr = cp.IfThenElse(True, y, z)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ def test_installed_solvers(self, solver):
model.solve(solver=solver)
assert [int(a) for a in v.value()] == [0, 1, 0]

s = cp.SolverLookup.get(solver)
s = cp.SolverLookup.get(solver, model)
s.solve()
assert [int(a) for a in v.value()] == [0, 1, 0]

Expand Down
22 changes: 11 additions & 11 deletions tests/test_trans_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class TransSimplify(unittest.TestCase):

def setUp(self) -> None:
self.bvs = cp.boolvar(shape=3, name="bv")
self.ivs = cp.intvar(0, 5, shape=3, name="iv")
self.ivs = cp.intvar(-1, 5, shape=3, name="iv")

self.transform = lambda x: simplify_boolean(toplevel_list(x))

Expand All @@ -19,10 +19,10 @@ def test_bool_ops(self):
expr = Operator("or", self.bvs.tolist() + [True])
self.assertEqual(str(self.transform(expr)), "[boolval(True)]")

expr = Operator("and", self.bvs.tolist() + [False]) + self.ivs[0] >= 10
self.assertEqual(str(self.transform(expr)), "[0 + (iv[0]) >= 10]")
expr = Operator("and", self.bvs.tolist() + [True]) + self.ivs[0] >= 10
self.assertEqual(str(self.transform(expr)), "[(and([bv[0], bv[1], bv[2]])) + (iv[0]) >= 10]")
expr = Operator("and", self.bvs.tolist() + [False]) + self.ivs[0] >= 3
self.assertEqual(str(self.transform(expr)), "[0 + (iv[0]) >= 3]")
expr = Operator("and", self.bvs.tolist() + [True]) + self.ivs[0] >= 3
self.assertEqual(str(self.transform(expr)), "[(and([bv[0], bv[1], bv[2]])) + (iv[0]) >= 3]")


expr = Operator("->", [self.bvs[0], True])
Expand All @@ -35,16 +35,16 @@ def test_bool_ops(self):
self.assertEqual(str(self.transform(expr)), "[boolval(True)]")

def test_bool_in_comp(self):
expr = self.ivs[0] >= False
self.assertEqual(str(self.transform(expr)), '[iv[0] >= 0]')
expr = self.ivs[0] > False
self.assertEqual(str(self.transform(expr)), '[iv[0] > 0]')
expr = self.ivs[0] >= True
self.assertEqual(str(self.transform(expr)), '[iv[0] >= 1]')

expr = (cp.sum(self.ivs) + True) >= 10
self.assertEqual(str(self.transform(expr)), '[sum([iv[0], iv[1], iv[2], 1]) >= 10]')

expr = True + self.ivs[0] >= False
self.assertEqual(str(self.transform(expr)), '[1 + (iv[0]) >= 0]')
expr = True + self.ivs[0] > False
self.assertEqual(str(self.transform(expr)), '[1 + (iv[0]) > 0]')

def test_boolvar_comps(self):
num_args = {"<0": -1, "0": 0, "]0..1[": 0.5, "1": 1, ">0": 2}
Expand Down Expand Up @@ -87,8 +87,8 @@ def test_simplify_expressions(self):
# with constant, does not change (surprisingly? but we cannot check what the res type is...)
expr = cp.max(self.ivs.tolist() + [False]) == 0
self.assertEqual(str(self.transform(expr)), '[max(iv[0],iv[1],iv[2],boolval(False)) == 0]')
expr = 0 == cp.max(self.ivs.tolist() + [True])
self.assertEqual(str(self.transform(expr)), '[max(iv[0],iv[1],iv[2],boolval(True)) == 0]')
expr = 1 == cp.max(self.ivs.tolist() + [True])
self.assertEqual(str(self.transform(expr)), '[max(iv[0],iv[1],iv[2],boolval(True)) == 1]')

expr = (self.ivs[0] <= self.ivs[1]) == 0
self.assertEqual(str(self.transform(expr)), '[not([(iv[0]) <= (iv[1])])]')
Expand Down
2 changes: 1 addition & 1 deletion tests/test_transf_reif.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_reif_element(self):

def test_reif_rewrite(self):
bvs = boolvar(shape=4, name="bvs")
ivs = intvar(1,9, shape=3, name="ivs")
ivs = intvar(0,9, shape=3, name="ivs")
rv = boolvar(name="rv")
arr = cpm_array([0,1,2])

Expand Down
Loading