Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions cpmpy/solvers/pindakaas.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,17 @@ def solver_var(self, cpm_var):
raise TypeError

def transform(self, cpm_expr):
cpm_cons = toplevel_list(cpm_expr)
cpm_cons = [cpm_expr]

# eagerly encode integer variables to maximize CSE
for x in get_variables(cpm_expr):
if (not isinstance(x, _BoolVarImpl)) and isinstance(x, _IntVarImpl):
_, cons = _encode_int_var(
self.ivarmap, x, _decide_encoding(x, encoding=self.encoding), csemap=self._csemap
)
cpm_cons += cons

cpm_cons = toplevel_list(cpm_cons)
cpm_cons = no_partial_functions(cpm_cons)
cpm_cons = decompose_in_tree(cpm_cons, csemap=self._csemap)
cpm_cons = simplify_boolean(cpm_cons)
Expand All @@ -240,7 +250,7 @@ def transform(self, cpm_expr):
cpm_cons = linearize_constraint(
cpm_cons, supported=frozenset({"sum", "wsum", "and", "or"}), csemap=self._csemap
)
cpm_cons = int2bool(cpm_cons, self.ivarmap, encoding=self.encoding)
cpm_cons = int2bool(cpm_cons, self.ivarmap, encoding=self.encoding, csemap=self._csemap)
return cpm_cons

def add(self, cpm_expr_orig):
Expand Down
101 changes: 52 additions & 49 deletions cpmpy/transformations/int2bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
from typing import List
import itertools
import math
from ..transformations.flatten_model import get_or_make_var
import cpmpy as cp
from abc import ABC, abstractmethod
from ..expressions.variables import _BoolVarImpl, _IntVarImpl, boolvar
from ..expressions.variables import _BoolVarImpl, _IntVarImpl
from ..expressions.globalconstraints import DirectConstraint
from ..expressions.core import Comparison, Operator, BoolVal
from ..expressions.core import Expression

UNKNOWN_COMPARATOR_ERROR = ValueError("Comparator is not known or should have been simplified by linearize.")
EMPTY_DOMAIN_ERROR = ValueError("Attempted to encode variable with empty domain (which is unsat)")


def int2bool(cpm_lst: List[Expression], ivarmap, encoding="auto"):
def int2bool(cpm_lst: List[Expression], ivarmap, encoding="auto", csemap=None):
"""Convert integer linear constraints to pseudo-boolean constraints. Requires `linearize` transformation."""
assert encoding in (
"auto",
Expand All @@ -25,12 +25,12 @@ def int2bool(cpm_lst: List[Expression], ivarmap, encoding="auto"):

cpm_out = []
for expr in cpm_lst:
constraints, domain_constraints = _encode_expr(ivarmap, expr, encoding)
constraints, domain_constraints = _encode_expr(ivarmap, expr, encoding, csemap=csemap)
cpm_out += domain_constraints + constraints
return cpm_out


def _encode_expr(ivarmap, expr, encoding):
def _encode_expr(ivarmap, expr, encoding, csemap=None):
"""Return encoded constraints and root-level constraints (e.g. domain constraints exactly-one, ..)."""
constraints = []
domain_constraints = []
Expand All @@ -41,7 +41,7 @@ def _encode_expr(ivarmap, expr, encoding):
elif expr.name == "->":
# Encode implication recursively
p, consequent = expr.args
constraints, domain_constraints = _encode_expr(ivarmap, consequent, encoding)
constraints, domain_constraints = _encode_expr(ivarmap, consequent, encoding, csemap=csemap)
return (
[p.implies(constraint) for constraint in constraints],
domain_constraints,
Expand All @@ -52,14 +52,14 @@ def _encode_expr(ivarmap, expr, encoding):
if type(lhs) is _BoolVarImpl:
return [expr], []
elif type(lhs) is _IntVarImpl:
return _encode_comparison(ivarmap, lhs, expr.name, rhs, encoding)
return _encode_comparison(ivarmap, lhs, expr.name, rhs, encoding, csemap=csemap)
elif lhs.name == "sum":
if len(lhs.args) == 1:
return _encode_expr(
ivarmap, Comparison(expr.name, lhs.args[0], rhs), encoding
ivarmap, Comparison(expr.name, lhs.args[0], rhs), encoding, csemap=csemap
) # even though it seems trivial (to call `_encode_comparison`), using recursion avoids bugs
else:
return _encode_linear(ivarmap, lhs.args, expr.name, rhs, encoding)
return _encode_linear(ivarmap, lhs.args, expr.name, rhs, encoding, csemap=csemap)
elif lhs.name == "wsum":
return _encode_linear(
ivarmap,
Expand All @@ -68,6 +68,7 @@ def _encode_expr(ivarmap, expr, encoding):
rhs,
encoding,
weights=lhs.args[0],
csemap=csemap,
)
else:
raise NotImplementedError(f"int2bool: comparison with lhs {lhs} not (yet?) supported")
Expand All @@ -76,26 +77,26 @@ def _encode_expr(ivarmap, expr, encoding):
raise NotImplementedError(f"int2bool: non-comparison {expr} not (yet?) supported")


def _encode_int_var(ivarmap, x, encoding):
def _encode_int_var(ivarmap, x, encoding, csemap=None):
"""Return encoding of integer variable `x` and its domain constraints (if newly encoded)."""
if isinstance(x, (BoolVal, _BoolVarImpl)):
raise TypeError
elif x.name in ivarmap: # already encoded
return ivarmap[x.name], []
else:
if encoding == "direct":
ivarmap[x.name] = IntVarEncDirect(x)
ivarmap[x.name] = IntVarEncDirect(x, csemap=csemap)
elif encoding == "order":
ivarmap[x.name] = IntVarEncOrder(x)
ivarmap[x.name] = IntVarEncOrder(x, csemap=csemap)
elif encoding == "binary":
ivarmap[x.name] = IntVarEncLog(x)
ivarmap[x.name] = IntVarEncLog(x, csemap=csemap)
else:
raise NotImplementedError(encoding)

return (ivarmap[x.name], ivarmap[x.name].encode_domain_constraint())
return (ivarmap[x.name], ivarmap[x.name].encode_domain_constraint(csemap=csemap))


def _encode_linear(ivarmap, xs, cmp, rhs, encoding, weights=None, check_bounds=True):
def _encode_linear(ivarmap, xs, cmp, rhs, encoding, weights=None, check_bounds=True, csemap=None):
"""
Convert a linear constraint to a pseudo-boolean constraint.

Expand Down Expand Up @@ -142,7 +143,7 @@ def _encode_linear(ivarmap, xs, cmp, rhs, encoding, weights=None, check_bounds=T
if isinstance(x, _BoolVarImpl):
terms += [(w, x)]
else:
x_enc, x_cons = _encode_int_var(ivarmap, x, _decide_encoding(x, cmp, encoding))
x_enc, x_cons = _encode_int_var(ivarmap, x, _decide_encoding(x, cmp, encoding), csemap=csemap)
domain_constraints += x_cons
# Encode the value of the integer variable as PB expression `(b_1*c_1) + ... + k`
new_terms, k = x_enc.encode_term(w)
Expand All @@ -165,12 +166,12 @@ def _encode_linear(ivarmap, xs, cmp, rhs, encoding, weights=None, check_bounds=T
return [Comparison(cmp, lhs, rhs)], domain_constraints


def _encode_comparison(ivarmap, lhs, cmp, rhs, encoding):
def _encode_comparison(ivarmap, lhs, cmp, rhs, encoding, csemap=None):
"""Encode integer comparison to PB."""
# TODO encode_expr should only use encode linear and check for "comparison" there
encoding = _decide_encoding(lhs, cmp, encoding)
lhs_enc, domain_constraints = _encode_int_var(ivarmap, lhs, encoding)
constraints = lhs_enc.encode_comparison(cmp, rhs)
lhs_enc, domain_constraints = _encode_int_var(ivarmap, lhs, encoding, csemap=csemap)
constraints = lhs_enc.encode_comparison(cmp, rhs, csemap=csemap)
return constraints, domain_constraints


Expand All @@ -181,28 +182,25 @@ def _decide_encoding(x, cmp=None, encoding="auto"):
elif _dom_size(x) >= 100:
# This heuristic is chosen to be small to favour the binary encoding. This is because the PB encoding (e.g. generalized totalizer, ...) of a direct/order encoded PB constraint is quite inefficient unless the AMO/IC side-constraint is taken into account (which is not the case for pysat/pblib/pysdd).
return "binary"
elif cmp in (None, "==", "!="):
elif cmp in ("==", "!="):
return "direct" # equalities suit the direct encoding
else: # inequalities suit the order encoding
else: # we use the order encoding for inequalities, and when we do not have `cmp`
return "order"


class IntVarEnc(ABC):
"""Abstract base class for integer variable encodings."""

def __init__(self, x, n, name):
"""Create encoding of integer variable `x` with `n` Boolean variables named by `name`."""
if _dom_size(x) == 0:
raise EMPTY_DOMAIN_ERROR

def __init__(self, x, x_enc, csemap=None):
"""Create encoding of integer variable `x` over the given Boolean expressions, `x_enc`. E.g. the direct encoding for `x` should provide `x_enc = ( x == 1, x == 2, ..)`. Any literals created (e.g. `b == ( x == 1 )`) are added to the `csemap` if provided."""
self._x = x # the encoded integer variable

if n == 0:
# `shape=(0,)` raises exception
self._xs = cp.cpm_array([])
else:
# `x`'s encoding variables
self._xs = boolvar(shape=(n,), name=name)
self._xs = []
for x_enc_i in x_enc:
lit, _ = get_or_make_var(x_enc_i, csemap=csemap)
# we can remove the defining constraints as the int var will be replaced
lit.name = f"⟦{x_enc_i}⟧"
self._xs.append(lit)
self._xs = cp.cpm_array(self._xs)

def vars(self):
"""Return the Boolean variables in the encoding."""
Expand All @@ -223,7 +221,7 @@ def decode(self):
return k

@abstractmethod
def encode_domain_constraint(self):
def encode_domain_constraint(self, csemap=None):
"""
Return domain constraints for the encoding.

Expand All @@ -233,7 +231,7 @@ def encode_domain_constraint(self):
pass

@abstractmethod
def encode_comparison(self, op, rhs):
def encode_comparison(self, op, rhs, csemap=None):
"""
Encode a comparison over the variable: self <op> rhs.

Expand Down Expand Up @@ -266,12 +264,12 @@ class IntVarEncDirect(IntVarEnc):
Uses a Boolean 'equality' variable for each value in the domain.
"""

def __init__(self, x):
def __init__(self, x, csemap=None):
"""Create direct encoding of integer variable `x`."""
# Requires |dom(x)| Boolean equality variables
super().__init__(x, _dom_size(x), f"EncDir({x.name})")
super().__init__(x, (x == d for d in _dom(x)), csemap=csemap)

def encode_domain_constraint(self):
def encode_domain_constraint(self, csemap=None):
"""
Return consistency constraints.

Expand All @@ -287,7 +285,7 @@ def eq(self, d):
else: # don't use `try .. except IndexError` since negative values wrap!
return BoolVal(False)

def encode_comparison(self, op, d):
def encode_comparison(self, op, d, csemap=None):
if op == "==":
# one yes, hence also rest no, if rhs is not in domain will set all to no
return [self.eq(d)]
Expand All @@ -314,11 +312,11 @@ class IntVarEncOrder(IntVarEnc):
Uses a Boolean 'inequality' variable for each value in the domain.
"""

def __init__(self, x):
def __init__(self, x, csemap=None):
"""Create order encoding of integer variable `x`."""
super().__init__(x, _dom_size(x) - 1, f"EncOrd({x.name})")
super().__init__(x, (x >= d for d in itertools.islice(_dom(x), 1, None)), csemap=csemap)

def encode_domain_constraint(self):
def encode_domain_constraint(self, csemap=None):
"""Return order encoding domain constraint (i.e. encoding variables are sorted in descending order e.g. `111000`)."""
if len(self._xs) <= 1:
return []
Expand All @@ -345,7 +343,7 @@ def geq(self, d):
else:
return self._xs[self._offset(d)]

def encode_comparison(self, cmp, d):
def encode_comparison(self, cmp, d, csemap=None):
if cmp == "==": # x>=d and x<d+1
return self.eq(d)
elif cmp == "!=": # x<d or x>=d+1
Expand All @@ -370,15 +368,16 @@ class IntVarEncLog(IntVarEnc):
Uses a Boolean 'bit' variable to represent `x` using the unsigned binary representation offset by its lower bound (e.g. for `x in 5..8`, the assignment `00` maps to `x=5`, and `11` to `x=8`). In other words, it is `k`-offset binary encoding where `k=x.lb`.
"""

def __init__(self, x):
def __init__(self, x, csemap=None):
"""Create binary encoding of integer variable `x`."""
bits = math.ceil(math.log2(_dom_size(x)))
super().__init__(x, bits, f"EncBin({x.name})")
super().__init__(x, (cp.boolvar(name=f"bit({x},{k})") for k in range(bits)), csemap=csemap)
# TODO possibly...: super().__init__(x, ((((x - x.lb) // (2 ** k)) % 2) == 1 for k in range(bits)), csemap=csemap)

def encode_domain_constraint(self):
def encode_domain_constraint(self, csemap=None):
"""Return binary encoding domain constraint (i.e. upper bound is respected with `self._x<=self._x.ub`. The lower bound is automatically enforced by offset binary which maps `000.. = self._x.lb`)."""
# encode directly to avoid bounds check for this seemingly tautological constraint
return self.encode_comparison("<=", self._x.ub, check_bounds=False)
return self.encode_comparison("<=", self._x.ub, check_bounds=False, csemap=csemap)

def _to_little_endian_offset_binary(self, d):
"""Return offset binary representation of `d` as Booleans in order of increasing significance ("little-endian").
Expand Down Expand Up @@ -407,7 +406,7 @@ def eq(self, d):
else: # don't use try IndexError since negative values wrap
return [BoolVal(False)]

def encode_comparison(self, cmp, d, check_bounds=True):
def encode_comparison(self, cmp, d, check_bounds=True, csemap=None):
if cmp == "==": # x>=d and x<d+1
return self.eq(d)
elif cmp == "!=": # x<d or x>=d+1
Expand All @@ -416,7 +415,7 @@ def encode_comparison(self, cmp, d, check_bounds=True):
elif cmp in (">=", "<="):
# TODO lexicographic encoding might be more effective, but currently we just use the PB encoding
constraint, domain_constraints = _encode_linear(
{self._x.name: self}, [self._x], cmp, d, None, check_bounds=check_bounds
{self._x.name: self}, [self._x], cmp, d, None, check_bounds=check_bounds, csemap=csemap
)
assert domain_constraints == [], (
f"{self._x} should have already been encoded, so no domain constraints should be returned"
Expand All @@ -429,5 +428,9 @@ def encode_term(self, w=1):
return [(w * (2**i), b) for i, b in enumerate(self._xs)], w * self._x.lb


def _dom(x):
    """Return a fresh iterator over the domain values of `x`.

    The values run from `x.lb` up to and including `x.ub`, in increasing order.
    The iterator is empty when `x.ub < x.lb`.
    """
    return iter(range(x.lb, x.ub + 1))


def _dom_size(x):
    """Return the number of values in the domain of `x`, computed as `x.ub + 1 - x.lb`."""
    return x.ub + 1 - x.lb
34 changes: 33 additions & 1 deletion tests/test_int2bool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

import cpmpy as cp
from cpmpy.transformations.flatten_model import flatten_constraint
from cpmpy.transformations.get_variables import get_variables
from cpmpy.expressions.core import Comparison, Operator, BoolVal
Expand All @@ -10,7 +11,6 @@
from cpmpy.transformations.int2bool import int2bool
from cpmpy.expressions.variables import _IntVarImpl, _BoolVarImpl, intvar, boolvar


# add some small but non-trivial integer variables (i.e. non-zero lower bounds, domain size not a power of two)
x = intvar(1, 3, name="x")
y = intvar(1, 3, name="y")
Expand Down Expand Up @@ -145,3 +145,35 @@ def show_int_var(x):
SOL_IN: {cons_sols}
SOL_OU: {flat_sols}
"""

def test_int2bool_cse_one_var(self):
# Pins the printed result of `CPM_pindakaas.transform` for `(x == 0) | (x == 2)` with
# `x in 0..2` and `encoding = "direct"`: one disjunction over the `⟦x == d⟧` literals,
# followed by a single `sum(...) == 1` constraint over all three literals of `x`.
x = cp.intvar(0, 2, name="x")
slv = cp.solvers.CPM_pindakaas()
slv.encoding = "direct"
# assert str(slv.transform((x == 0) )) == "[(EncDir(x)[0]) + (EncDir (x)[1]) == 1, EncDir(x)[0], ~EncDir(x)[1]]"
assert (
str(slv.transform((x == 0) | (x == 2)))
== "[(⟦x == 0⟧) or (⟦x == 2⟧), sum([⟦x == 0⟧, ⟦x == 1⟧, ⟦x == 2⟧]) == 1]"
)

def test_int2bool_cse_one_var_order(self):
# Pins the printed result of `CPM_pindakaas.transform` for `(x >= 1) | (x >= 2)` with
# `x in 0..2` and `encoding = "order"`: one disjunction over the `⟦x >= d⟧` literals,
# followed by a single weighted-sum constraint `⟦x >= 2⟧ - ⟦x >= 1⟧ <= 0` on those literals.
x = cp.intvar(0, 2, name="x")
slv = cp.solvers.CPM_pindakaas()
slv.encoding = "order"
assert (
str(slv.transform((x >= 1) | (x >= 2)))
== "[(⟦x >= 1⟧) or (⟦x >= 2⟧), sum([1, -1] * (⟦x >= 2⟧, ⟦x >= 1⟧)) <= 0]"
)
# TODO this could be a CSE improvement?
# assert str(slv.transform((x >= 1) | (x < 2))) == "[(⟦x == 0⟧) or (⟦x == 2⟧), sum([⟦x == 0⟧, ⟦x == 1⟧, ⟦x == 2⟧]) == 1]"

def test_int2bool_cse_two_vars(self):
# Pins the printed result of `CPM_pindakaas.transform` for `(x == 0) | (y == 2)` with
# `x, y in 0..2` and `encoding = "direct"`: one disjunction mixing literals of both
# variables, followed by one `sum(...) == 1` constraint for `x` and one for `y`.
slv = cp.solvers.CPM_pindakaas()
x = cp.intvar(0, 2, name="x")
y = cp.intvar(0, 2, name="y")
slv.encoding = "direct"
assert (
str(slv.transform((x == 0) | (y == 2)))
== "[(⟦x == 0⟧) or (⟦y == 2⟧), sum([⟦x == 0⟧, ⟦x == 1⟧, ⟦x == 2⟧]) == 1, sum([⟦y == 0⟧, ⟦y == 1⟧, ⟦y == 2⟧]) == 1]"
)

Loading