diff --git a/cpmpy/solvers/pindakaas.py b/cpmpy/solvers/pindakaas.py index 20540903a..5e4bfb3b5 100755 --- a/cpmpy/solvers/pindakaas.py +++ b/cpmpy/solvers/pindakaas.py @@ -230,7 +230,17 @@ def solver_var(self, cpm_var): raise TypeError def transform(self, cpm_expr): - cpm_cons = toplevel_list(cpm_expr) + cpm_cons = [cpm_expr] + + # eagerly encode integer variables to maximize CSE + for x in get_variables(cpm_expr): + if (not isinstance(x, _BoolVarImpl)) and isinstance(x, _IntVarImpl): + _, cons = _encode_int_var( + self.ivarmap, x, _decide_encoding(x, encoding=self.encoding), csemap=self._csemap + ) + cpm_cons += cons + + cpm_cons = toplevel_list(cpm_cons) cpm_cons = no_partial_functions(cpm_cons) cpm_cons = decompose_in_tree(cpm_cons, csemap=self._csemap) cpm_cons = simplify_boolean(cpm_cons) @@ -240,7 +250,7 @@ def transform(self, cpm_expr): cpm_cons = linearize_constraint( cpm_cons, supported=frozenset({"sum", "wsum", "and", "or"}), csemap=self._csemap ) - cpm_cons = int2bool(cpm_cons, self.ivarmap, encoding=self.encoding) + cpm_cons = int2bool(cpm_cons, self.ivarmap, encoding=self.encoding, csemap=self._csemap) return cpm_cons def add(self, cpm_expr_orig): diff --git a/cpmpy/transformations/int2bool.py b/cpmpy/transformations/int2bool.py index 161213ca9..5a31efd4c 100644 --- a/cpmpy/transformations/int2bool.py +++ b/cpmpy/transformations/int2bool.py @@ -3,18 +3,18 @@ from typing import List import itertools import math +from ..transformations.flatten_model import get_or_make_var import cpmpy as cp from abc import ABC, abstractmethod -from ..expressions.variables import _BoolVarImpl, _IntVarImpl, boolvar +from ..expressions.variables import _BoolVarImpl, _IntVarImpl from ..expressions.globalconstraints import DirectConstraint from ..expressions.core import Comparison, Operator, BoolVal from ..expressions.core import Expression UNKNOWN_COMPARATOR_ERROR = ValueError("Comparator is not known or should have been simplified by linearize.") -EMPTY_DOMAIN_ERROR = ValueError("Attempted to encode variable with empty domain (which is unsat)") -def int2bool(cpm_lst: List[Expression], ivarmap, encoding="auto"): +def int2bool(cpm_lst: List[Expression], ivarmap, encoding="auto", csemap=None): """Convert integer linear constraints to pseudo-boolean constraints. Requires `linearize` transformation.""" assert encoding in ( "auto", @@ -25,12 +25,12 @@ def int2bool(cpm_lst: List[Expression], ivarmap, encoding="auto"): cpm_out = [] for expr in cpm_lst: - constraints, domain_constraints = _encode_expr(ivarmap, expr, encoding) + constraints, domain_constraints = _encode_expr(ivarmap, expr, encoding, csemap=csemap) cpm_out += domain_constraints + constraints return cpm_out -def _encode_expr(ivarmap, expr, encoding): +def _encode_expr(ivarmap, expr, encoding, csemap=None): """Return encoded constraints and root-level constraints (e.g. domain constraints exactly-one, ..).""" constraints = [] domain_constraints = [] @@ -41,7 +41,7 @@ def _encode_expr(ivarmap, expr, encoding): elif expr.name == "->": # Encode implication recursively p, consequent = expr.args - constraints, domain_constraints = _encode_expr(ivarmap, consequent, encoding) + constraints, domain_constraints = _encode_expr(ivarmap, consequent, encoding, csemap=csemap) return ( [p.implies(constraint) for constraint in constraints], domain_constraints, @@ -52,14 +52,14 @@ def _encode_expr(ivarmap, expr, encoding): if type(lhs) is _BoolVarImpl: return [expr], [] elif type(lhs) is _IntVarImpl: - return _encode_comparison(ivarmap, lhs, expr.name, rhs, encoding) + return _encode_comparison(ivarmap, lhs, expr.name, rhs, encoding, csemap=csemap) elif lhs.name == "sum": if len(lhs.args) == 1: return _encode_expr( - ivarmap, Comparison(expr.name, lhs.args[0], rhs), encoding + ivarmap, Comparison(expr.name, lhs.args[0], rhs), encoding, csemap=csemap ) # even though it seems trivial (to call `_encode_comparison`), using recursion avoids bugs else: - return _encode_linear(ivarmap, lhs.args, expr.name, rhs, encoding) + return _encode_linear(ivarmap, lhs.args, expr.name, rhs, encoding, csemap=csemap) elif lhs.name == "wsum": return _encode_linear( ivarmap, @@ -68,6 +68,7 @@ def _encode_expr(ivarmap, expr, encoding): rhs, encoding, weights=lhs.args[0], + csemap=csemap, ) else: raise NotImplementedError(f"int2bool: comparison with lhs {lhs} not (yet?) supported") @@ -76,7 +77,7 @@ def _encode_expr(ivarmap, expr, encoding): raise NotImplementedError(f"int2bool: non-comparison {expr} not (yet?) supported") -def _encode_int_var(ivarmap, x, encoding): +def _encode_int_var(ivarmap, x, encoding, csemap=None): """Return encoding of integer variable `x` and its domain constraints (if newly encoded).""" if isinstance(x, (BoolVal, _BoolVarImpl)): raise TypeError @@ -84,18 +85,18 @@ def _encode_int_var(ivarmap, x, encoding): return ivarmap[x.name], [] else: if encoding == "direct": - ivarmap[x.name] = IntVarEncDirect(x) + ivarmap[x.name] = IntVarEncDirect(x, csemap=csemap) elif encoding == "order": - ivarmap[x.name] = IntVarEncOrder(x) + ivarmap[x.name] = IntVarEncOrder(x, csemap=csemap) elif encoding == "binary": - ivarmap[x.name] = IntVarEncLog(x) + ivarmap[x.name] = IntVarEncLog(x, csemap=csemap) else: raise NotImplementedError(encoding) - return (ivarmap[x.name], ivarmap[x.name].encode_domain_constraint()) + return (ivarmap[x.name], ivarmap[x.name].encode_domain_constraint(csemap=csemap)) -def _encode_linear(ivarmap, xs, cmp, rhs, encoding, weights=None, check_bounds=True): +def _encode_linear(ivarmap, xs, cmp, rhs, encoding, weights=None, check_bounds=True, csemap=None): """ Convert a linear constraint to a pseudo-boolean constraint. @@ -142,7 +143,7 @@ def _encode_linear(ivarmap, xs, cmp, rhs, encoding, weights=None, check_bounds=T if isinstance(x, _BoolVarImpl): terms += [(w, x)] else: - x_enc, x_cons = _encode_int_var(ivarmap, x, _decide_encoding(x, cmp, encoding)) + x_enc, x_cons = _encode_int_var(ivarmap, x, _decide_encoding(x, cmp, encoding), csemap=csemap) domain_constraints += x_cons # Encode the value of the integer variable as PB expression `(b_1*c_1) + ... + k` new_terms, k = x_enc.encode_term(w) @@ -165,12 +166,12 @@ def _encode_linear(ivarmap, xs, cmp, rhs, encoding, weights=None, check_bounds=T return [Comparison(cmp, lhs, rhs)], domain_constraints -def _encode_comparison(ivarmap, lhs, cmp, rhs, encoding): +def _encode_comparison(ivarmap, lhs, cmp, rhs, encoding, csemap=None): """Encode integer comparison to PB.""" # TODO encode_expr should only use encode linear and check for "comparison" there encoding = _decide_encoding(lhs, cmp, encoding) - lhs_enc, domain_constraints = _encode_int_var(ivarmap, lhs, encoding) - constraints = lhs_enc.encode_comparison(cmp, rhs) + lhs_enc, domain_constraints = _encode_int_var(ivarmap, lhs, encoding, csemap=csemap) + constraints = lhs_enc.encode_comparison(cmp, rhs, csemap=csemap) return constraints, domain_constraints @@ -181,28 +182,25 @@ def _decide_encoding(x, cmp=None, encoding="auto"): elif _dom_size(x) >= 100: # This heuristic is chosen to be small to favour the binary encoding. This is because the PB encoding (e.g. generalized totalizer, ...) of a direct/order encoded PB constraints is quite inefficient unless the AMO/IC side-constraint is taken into account (which is not the case for pysat/pblib/pysdd). return "binary" - elif cmp in (None, "==", "!="): + elif cmp in ("==", "!="): return "direct" # equalities suit the direct encoding - else: # inequalities suit the order encoding + else: # we use the order encoding for inequalities, en when we do not have `cmp` return "order" class IntVarEnc(ABC): """Abstract base class for integer variable encodings.""" - def __init__(self, x, n, name): - """Create encoding of integer variable `x` with `n` Boolean variables named by `name`.""" - if _dom_size(x) == 0: - raise EMPTY_DOMAIN_ERROR - + def __init__(self, x, x_enc, csemap=None): + """Create encoding of integer variable `x` over the given Boolean expressions, `x_enc`. E.g. the direct encoding for `x` should provide `x_enc = ( x == 1, x == 2, ..)`. Any literals created (e.g. b == ( x == 1 )`) are added to the `csemap` if provided.""" self._x = x # the encoded integer variable - - if n == 0: - # `shape=(0,)` raises exception - self._xs = cp.cpm_array([]) - else: - # `x`'s encoding variables - self._xs = boolvar(shape=(n,), name=name) + self._xs = [] + for x_enc_i in x_enc: + lit, _ = get_or_make_var(x_enc_i, csemap=csemap) + # we can remove the definining constraints as the int var will be replaced + lit.name = f"⟦{x_enc_i}⟧" + self._xs.append(lit) + self._xs = cp.cpm_array(self._xs) def vars(self): """Return the Boolean variables in the encoding.""" @@ -223,7 +221,7 @@ def decode(self): return k @abstractmethod - def encode_domain_constraint(self): + def encode_domain_constraint(self, csemap=None): """ Return domain constraints for the encoding. @@ -233,7 +231,7 @@ def encode_domain_constraint(self): pass @abstractmethod - def encode_comparison(self, op, rhs): + def encode_comparison(self, op, rhs, csemap=None): """ Encode a comparison over the variable: self rhs. @@ -266,12 +264,12 @@ class IntVarEncDirect(IntVarEnc): Uses a Boolean 'equality' variable for each value in the domain. """ - def __init__(self, x): + def __init__(self, x, csemap=None): """Create direct encoding of integer variable `x`.""" # Requires |dom(x)| Boolean equality variables - super().__init__(x, _dom_size(x), f"EncDir({x.name})") + super().__init__(x, (x == d for d in _dom(x)), csemap=csemap) - def encode_domain_constraint(self): + def encode_domain_constraint(self, csemap=None): """ Return consistency constraints. @@ -287,7 +285,7 @@ def eq(self, d): else: # don't use `try .. except IndexError` since negative values wrap! return BoolVal(False) - def encode_comparison(self, op, d): + def encode_comparison(self, op, d, csemap=None): if op == "==": # one yes, hence also rest no, if rhs is not in domain will set all to no return [self.eq(d)] @@ -314,11 +312,11 @@ class IntVarEncOrder(IntVarEnc): Uses a Boolean 'inequality' variable for each value in the domain. """ - def __init__(self, x): + def __init__(self, x, csemap=None): """Create order encoding of integer variable `x`.""" - super().__init__(x, _dom_size(x) - 1, f"EncOrd({x.name})") + super().__init__(x, (x >= d for d in itertools.islice(_dom(x), 1, None)), csemap=csemap) - def encode_domain_constraint(self): + def encode_domain_constraint(self, csemap=None): """Return order encoding domain constraint (i.e. encoding variables are sorted in descending order e.g. `111000`).""" if len(self._xs) <= 1: return [] @@ -345,7 +343,7 @@ def geq(self, d): else: return self._xs[self._offset(d)] - def encode_comparison(self, cmp, d): + def encode_comparison(self, cmp, d, csemap=None): if cmp == "==": # x>=d and x=d and x=", "<="): # TODO lexicographic encoding might be more effective, but currently we just use the PB encoding constraint, domain_constraints = _encode_linear( - {self._x.name: self}, [self._x], cmp, d, None, check_bounds=check_bounds + {self._x.name: self}, [self._x], cmp, d, None, check_bounds=check_bounds, csemap=csemap ) assert domain_constraints == [], ( f"{self._x} should have already been encoded, so no domain constraints should be returned" @@ -429,5 +428,9 @@ def encode_term(self, w=1): return [(w * (2**i), b) for i, b in enumerate(self._xs)], w * self._x.lb +def _dom(x): + return iter(range(x.lb, x.ub + 1)) + + def _dom_size(x): return x.ub + 1 - x.lb diff --git a/tests/test_int2bool.py b/tests/test_int2bool.py index 964bbed6e..ba60764b2 100644 --- a/tests/test_int2bool.py +++ b/tests/test_int2bool.py @@ -1,5 +1,6 @@ import pytest +import cpmpy as cp from cpmpy.transformations.flatten_model import flatten_constraint from cpmpy.transformations.get_variables import get_variables from cpmpy.expressions.core import Comparison, Operator, BoolVal @@ -10,7 +11,6 @@ from cpmpy.transformations.int2bool import int2bool from cpmpy.expressions.variables import _IntVarImpl, _BoolVarImpl, intvar, boolvar - # add some small but non-trivial integer variables (i.e. non-zero lower bounds, domain size not a power of two) x = intvar(1, 3, name="x") y = intvar(1, 3, name="y") @@ -145,3 +145,35 @@ def show_int_var(x): SOL_IN: {cons_sols} SOL_OU: {flat_sols} """ + + def test_int2bool_cse_one_var(self): + x = cp.intvar(0, 2, name="x") + slv = cp.solvers.CPM_pindakaas() + slv.encoding = "direct" + # assert str(slv.transform((x == 0) )) == "[(EncDir(x)[0]) + (EncDir (x)[1]) == 1, EncDir(x)[0], ~EncDir(x)[1]]" + assert ( + str(slv.transform((x == 0) | (x == 2))) + == "[(⟦x == 0⟧) or (⟦x == 2⟧), sum([⟦x == 0⟧, ⟦x == 1⟧, ⟦x == 2⟧]) == 1]" + ) + + def test_int2bool_cse_one_var_order(self): + x = cp.intvar(0, 2, name="x") + slv = cp.solvers.CPM_pindakaas() + slv.encoding = "order" + assert ( + str(slv.transform((x >= 1) | (x >= 2))) + == "[(⟦x >= 1⟧) or (⟦x >= 2⟧), sum([1, -1] * (⟦x >= 2⟧, ⟦x >= 1⟧)) <= 0]" + ) + # TODO this could be a CSE improvement? + # assert str(slv.transform((x >= 1) | (x < 2))) == "[(⟦x == 0⟧) or (⟦x == 2⟧), sum([⟦x == 0⟧, ⟦x == 1⟧, ⟦x == 2⟧]) == 1]" + + def test_int2bool_cse_two_vars(self): + slv = cp.solvers.CPM_pindakaas() + x = cp.intvar(0, 2, name="x") + y = cp.intvar(0, 2, name="y") + slv.encoding = "direct" + assert ( + str(slv.transform((x == 0) | (y == 2))) + == "[(⟦x == 0⟧) or (⟦y == 2⟧), sum([⟦x == 0⟧, ⟦x == 1⟧, ⟦x == 2⟧]) == 1, sum([⟦y == 0⟧, ⟦y == 1⟧, ⟦y == 2⟧]) == 1]" + ) +