Source code for wmpy.enumeration.total

from typing import Iterable, Optional, cast

from pysmt.environment import Environment, get_env
from pysmt.fnode import FNode

from wmpy.core.weights import Weights


[docs] class TotalEnumerator: """This class implements a baseline total enumerator using the Z3 SMT solver.""" def __init__( self, support: FNode, weight: Optional[FNode] = None, env: Optional[Environment] = None, ) -> None: """Default constructor. Args: weights: the weight function as a pysmt term support: the support of the weight function (a pysmt formula) env: the pysmt environment (optional) """ self.support = support if env is not None: self.env = env else: self.env = cast(Environment, get_env()) if weight is None: weight = self.env.formula_manager.Real(1) # Default weight is 1 self.weights = Weights(weight, self.env)
[docs] def enumerate(self, query: FNode) -> Iterable[tuple[dict[FNode, bool], int]]: """Enumerates (possibly partial) truth assignments for the given formula. Since the truth assignments (TA) are always total, the number of unassigned Boolean variables is always 0. Args: query: the query (a pysmt formula) Returns: An iterable of tuples <TA, 0> where: - TA is a dictionary {pysmt_atom : bool} representing (partial) truth assignment """ mgr = self.env.formula_manager # conjoin query and support formula = mgr.And(query, self.support) # sort the different atoms atoms = self.env.ao.get_atoms(formula) | self.weights.get_atoms() with self.env.factory.Solver(name="z3") as smt_solver: smt_solver.add_assertion(formula) while smt_solver.solve(): model = {} blocking_clause = [] for a in atoms: literal = ( a if smt_solver.get_value(a).constant_value() else mgr.Not(a) ) model[a] = not literal.is_not() blocking_clause.append(literal) smt_solver.add_assertion(mgr.Not(mgr.And(*blocking_clause))) yield model, 0