import z3
import abc
import collections
import logging
# Module-level logger; Z3Checker.solve() uses it to dump the solver
# state at DEBUG level when an unsat result is encountered.
logger = logging.getLogger(__name__)
class SolutionNotFoundError(Exception):
    '''
    Raised when a solver cannot find a satisfying solution

    Attributes:
        message: human-readable description of the failure
            (also used as the exception text)
        labels: optional back-annotation labels attached by
            the raiser; None when not provided
    '''

    def __init__(self, message, labels=None):
        self.message = message
        self.labels = labels
        # Forward the message so str(exc) / tracebacks display it
        super().__init__(self.message)
class AbstractSolver(abc.ABC):
    '''
    Interface that every constraint solver backend must implement

    Concrete subclasses provide formula storage, solving,
    checkpoint / revert, variable generation and the logical
    & arithmetic operators used to build formulae
    '''

    @abc.abstractmethod
    def append(self, formula, label=None):
        '''
        Append formula to checker.
        Note: Please use bbox variables to create formulae
              Otherwise you will need to manage types
              yourself
        '''

    @abc.abstractmethod
    def solve(self):
        '''
        Solve & return solutions
        If no solution is found, raise SolutionNotFoundError
        '''

    @abc.abstractmethod
    def label(self, object):
        '''
        Generate label that can be used for
        back-annotation
        Note: Return None if solver
              doesn't support back-annotation
        '''

    @abc.abstractmethod
    def annotate(self, formulae, label):
        '''
        Yield formulae annotated with label
        Note: Input 'formulae' is iterable.
        Note: MUST return an iterable object
        Note: Return original iterable if solver
              doesn't support back-annotation
        '''

    @abc.abstractmethod
    def checkpoint(self):
        '''
        Checkpoint current state of solver
        Note: We assume incremental solving here
              May need to revisit if we have to
              rebuild solution from scratch
        '''

    @abc.abstractmethod
    def revert(self):
        '''
        Revert to last checkpoint
        Note: We assume incremental solving here
              May need to revisit if we have to
              rebuild solution from scratch
        '''

    @abc.abstractmethod
    def bbox_vars(self, name):
        '''
        Generate a single namedtuple containing
        appropriate checker variables for
        placement constraints
        '''

    def iter_bbox_vars(self, names):
        '''
        Helper utility to generate multiple bbox variables
        The output should be an iterator that allows you
        to loop over bboxes (use `yield` when possible)
        '''
        # Lazily delegate to bbox_vars, one name at a time
        for name in names:
            yield self.bbox_vars(name)

    @abc.abstractmethod
    def And(self, *expressions):
        '''
        Logical `And` of all arguments
        Note: arguments are assumed to be
              boolean expressions
        '''

    @abc.abstractmethod
    def Or(self, *expressions):
        '''
        Logical `Or` of all arguments
        Note: arguments are assumed to be
              boolean expressions
        '''

    @abc.abstractmethod
    def Not(self, expr):
        '''
        Logical `Not` of argument
        Note: argument is assumed to be
              a boolean expression
        '''

    @abc.abstractmethod
    def Implies(self, expr1, expr2):
        '''
        expr1 => expr2
        Note: both arguments are assumed
              to be boolean expressions
        '''

    @abc.abstractmethod
    def cast(self, expr, type_):
        '''
        cast `expr` to `type_`
        Note: Use with care. Not all
              engines support all types
        '''

    @abc.abstractmethod
    def Abs(self, expr):
        '''
        Absolute value of expression
        Note: argument is assumed to be
              arithmetic expression
        '''
# Lightweight record pairing a formula with the label used to track it;
# Z3Checker.annotate() produces these and Z3Checker.append() unpacks them.
AnnotatedFormula = collections.namedtuple(
    'AnnotatedFormula',
    ('formula', 'label'),
)
class Z3Checker(AbstractSolver):
    '''
    AbstractSolver implementation backed by a z3.Solver

    The underlying solver is created with unsat-core
    tracking enabled so that labelled formulae can be
    reported back when no solution exists
    '''

    def __init__(self):
        self._solver = z3.Solver()
        # Required for unsat_core() to return tracked labels in solve()
        self._solver.set(unsat_core=True)

    def annotate(self, formulae, label):
        '''
        Yield a single AnnotatedFormula combining all
        `formulae` under `label`
        Note: Multiple formulae are merged using `And`;
              a single formula is passed through as is
        Note: Yields nothing if `formulae` is empty
        '''
        # Materialize so that generators / one-shot iterables work too
        formulae = list(formulae)
        if not formulae:
            return
        yield AnnotatedFormula(
            formula=self.And(*formulae) if len(formulae) > 1 else formulae[0],
            label=label)

    def append(self, formula, label=None):
        '''
        Append formula to the underlying z3 solver
        Note: An AnnotatedFormula is tracked using its own
              label (the `label` argument is ignored)
        Note: Any other formula is tracked only if `label`
              is provided; otherwise it is added untracked
        '''
        if isinstance(formula, AnnotatedFormula):
            self._solver.assert_and_track(formula.formula, formula.label)
        elif label is not None:
            self._solver.assert_and_track(formula, label)
        else:
            self._solver.add(formula)

    def solve(self):
        '''
        Check satisfiability of all appended formulae
        Raises SolutionNotFoundError (carrying the unsat
        core as `labels`) if the solver reports unsat
        Note: Returns None; any result other than unsat
              is not treated as a failure here
        '''
        r = self._solver.check()
        if r == z3.unsat:
            # Widen z3's global pretty-printer limits so that the
            # solver dump below is not truncated
            z3.set_option(max_depth=10000, max_args=100, max_lines=10000)
            # Lazy %-formatting: solver is only rendered if DEBUG is emitted
            logger.debug("Unsat encountered: %s", self._solver)
            raise SolutionNotFoundError(
                message='No satisfying solution could be found. Please review constraints.',
                labels=self._solver.unsat_core())

    def checkpoint(self):
        '''
        Push a backtracking point onto the z3 solver
        '''
        self._solver.push()

    def revert(self):
        '''
        Pop the most recent backtracking point
        '''
        self._solver.pop()

    def bbox_vars(self, name):
        '''
        Generate a `Bbox` namedtuple with fields
        llx, lly, urx, ury holding z3 integer variables
        named `<name>_llx`, `<name>_lly`, etc.
        '''
        # generate new bbox
        return self._generate_var(
            'Bbox',
            llx=f'{name}_llx',
            lly=f'{name}_lly',
            urx=f'{name}_urx',
            ury=f'{name}_ury')

    def label(self, object):
        '''
        Generate a z3 Bool usable as tracking label,
        named after a hash of repr(object)
        Note: hash() of str is salted per process unless
              PYTHONHASHSEED is fixed, so label names
              may differ between runs
        '''
        # Z3 throws 'index out of bounds' error
        # if more than 9 digits are used
        return z3.Bool(
            hash(repr(object)) % 10**9
        )

    @staticmethod
    def Or(*expressions):
        '''
        Logical `Or` of all arguments (z3.Or)
        '''
        return z3.Or(*expressions)

    @staticmethod
    def And(*expressions):
        '''
        Logical `And` of all arguments (z3.And)
        '''
        return z3.And(*expressions)

    @staticmethod
    def Not(expr):
        '''
        Logical `Not` of argument (z3.Not)
        '''
        return z3.Not(expr)

    @staticmethod
    def Abs(expr):
        '''
        Absolute value of expression, encoded
        as an if-then-else on the sign of `expr`
        '''
        return z3.If(expr >= 0, expr, expr * -1)

    @staticmethod
    def Implies(expr1, expr2):
        '''
        expr1 => expr2 (z3.Implies)
        '''
        return z3.Implies(expr1, expr2)

    @staticmethod
    def cast(expr, type_):
        '''
        cast `expr` to `type_`
        Note: Only `float` is supported (via z3.ToReal);
              anything else raises NotImplementedError
        '''
        if type_ == float:
            return z3.ToReal(expr)
        else:
            raise NotImplementedError

    @staticmethod
    def _generate_var(name, **fields):
        '''
        Generate z3 integer variable(s)
        Without `fields`: a single z3.Int called `name`
        With `fields`: a namedtuple of type `name` whose
        attributes are the keyword names and whose values
        are z3.Int variables named by the keyword values
        '''
        if fields:
            # One z3.Int per field rather than z3.Ints on a joined string:
            # z3.Ints splits on spaces, which would miscount variables if
            # any variable name itself contained a space
            return collections.namedtuple(
                name,
                list(fields),
            )(*(z3.Int(var_name) for var_name in fields.values()))
        else:
            return z3.Int(name)