Source code for wmpy.core.weights

from typing import Any, Collection, Generator, Iterable
import pysmt.operators as op
from pysmt.environment import Environment
from pysmt.fnode import FNode
from pysmt.formula import FormulaManager
from pysmt.oracles import AtomsOracle
from pysmt.rewritings import NNFizer
from pysmt.typing import BOOL
from pysmt.walkers import DagWalker, IdentityDagWalker, TreeWalker, handles

from wmpy.core.utils import is_atom, is_clause, is_cnf, is_literal

# TODO: (maybe) move WeightAtomsFinder, WeightsEvaluator inside Weights (making them private)?


[docs] class Weights: """This class encodes a piecewise weight function. Attributes: weight_func: the weight function in pysmt format env: the pysmt environment atoms_finder: TODO: add a bit of documentation for this evaluator: internal class for function evaluation """ def __init__(self, weight_func: FNode, env: Environment): """Default constructor. Args: weight_func: the pysmt expression representing the weight function env: the pysmt environment """ self.env = env self.weight_func = weight_func self.atoms_finder = WeightAtomsFinder(env=env) self.evaluator = WeightsEvaluator(self)
[docs] def compute_skeleton(self) -> FNode: """Computes the "skeleton", a SMT formula that encodes the structure of the weight function. Conjoining the skeleton with the support formula can be advantageous when using partial enumeration. Returns: A pysmt formula that encodes the structure of the weight. """ return WeightConverterSkeleton(env=self.env).convert(self.weight_func)
[docs] def get_atoms(self) -> Collection[FNode]: """Returns the atoms contained in the (conditions of the) weight expressions.""" atoms = self.atoms_finder.get_atoms(self.weight_func) return atoms if atoms is not None else frozenset([])
[docs] def weight_from_assignment(self, assignment: dict[FNode, bool]) -> FNode: """Evaluates the weight function given a total truth assignment to its conditions. Args: assignment: the truth assignment as returned by an Enumerator Returns: A pysmt term that correspond to unconditional weight function obtained by assigning a truth value to the conditions. Raises: ValueError if the TA is not total. """ return self.evaluator.evaluate(assignment)
def __str__(self) -> str: return ( "Weight {{" "\t{weight}\n" "}}".format( weight=self.weight_func.serialize(), ) )
[docs] class WeightsEvaluator(TreeWalker): """This internal class implements the weight evaluation given a truth assignment to its conditions.""" def __init__(self, weights: Weights): super().__init__(weights.env) self.mgr: FormulaManager = self.env.formula_manager self.simplifier = self.env.simplifier self.substituter = self.env.substituter self.weight_node: FNode = weights.weight_func self.assignment: dict[FNode, FNode] = {} self.result: list[FNode] = [] # stack to store the results of the evaluation
[docs] def evaluate(self, assignment: dict[FNode, bool]) -> FNode: """Evaluates the weight function given a total TA to its conditions. Returns: The simplified expression in pysmt format. Raises: ValueError if the TA is not total. """ self.result.clear() self.assignment = {atom: self.mgr.Bool(v) for atom, v in assignment.items()} self.walk(self.weight_node) assert len(self.result) == 1, f"Expected a single result, got {self.result}" return self.result.pop()
[docs] def walk_ite(self, formula: FNode) -> Generator[FNode, None, None]: cond, then, _else = formula.args() value = self._evaluate_condition(cond) yield then if value else _else # recursion on the branch that is True
@handles(op.SYMBOL) @handles(op.CONSTANTS) def walk_leaf(self, formula: FNode) -> None: self.result.append(formula) @handles(op.IRA_OPERATORS) def walk_operator(self, formula: FNode) -> Generator[FNode, None, None]: for arg in reversed(formula.args()): yield arg # recurse on children new_children = (self.result.pop() for _ in formula.args()) self.result.append( self.mgr.create_node( node_type=formula.node_type(), args=tuple(new_children) ) ) def _evaluate_condition(self, condition: FNode) -> bool: val = self.simplifier.simplify( self.substituter.substitute(condition, self.assignment) ) if not val.is_bool_constant(): msg = ( "Weight condition " + self.env.serializer.serialize(condition) + "\n\n cannot be evaluated with assignment " + "\n".join([str((x, v)) for x, v in self.assignment.items()]) + "\n\n simplified into " + self.env.serializer.serialize(condition) ) raise ValueError(msg) return val.constant_value()
[docs] class WeightAtomsFinder(AtomsOracle): """TODO"""
[docs] def walk_ite( self, formula: FNode, args: list[frozenset[FNode]], **kwargs: Any ) -> frozenset[FNode]: return frozenset(x for a in args if a is not None for x in a)
@handles(op.IRA_OPERATORS) def walk_theory_op( # pyright: ignore self, formula: FNode, args: list[frozenset[FNode]], **kwargs: Any ) -> frozenset[FNode]: return frozenset(x for a in args if a is not None for x in a)
[docs] class WeightConverterSkeleton(TreeWalker): """This internal class implements the conversion of a weight function into a weight skeleton, as described in "Enhancing SMT-based Weighted Model Integration by structure awareness" (Spallitta et al., 2024). """ def __init__(self, env: Environment): super().__init__(env) self.mgr = self.env.formula_manager self.cond_labels: set[FNode] = set() self.cnfizer = PolarityCNFizer(env=self.env) self.branch_condition: list[FNode] = [] # clause as a list of FNodes self.clauses: list[FNode] = ( [] ) # list of clauses, each clause is an Or of FNodes def new_cond_label(self) -> FNode: b = self.mgr.FreshSymbol(typename=BOOL, template="CNDB%s") self.cond_labels.add(b) return b def convert(self, weight_func: FNode) -> FNode: self.clauses.clear() self.walk(weight_func) return self.mgr.And(self.clauses) @handles(op.SYMBOL) @handles(op.CONSTANTS) def walk_no_conditions(self, formula: FNode) -> None: return @handles(op.IRA_OPERATORS) def walk_operator(self, formula: FNode) -> Generator[FNode, None, None]: for arg in formula.args(): yield arg
[docs] def walk_ite(self, formula: FNode) -> Generator[FNode, None, None]: phi: FNode left: FNode right: FNode phi, left, right = formula.args() if is_atom(phi): # Trick to force the splitting on phi on the current branch represented by branch_condition # (here branch_condition is Not(conds)). # In the original algorithm, we would have added: # (conds -> (phi v not phi)). # This would require a custom MathSAT version to avoid the simplification of the valid clause. # # Here, instead, we add: # (conds -> exists k.CNF(phi <-> k)) # which is equivalent to the above approach, but does not get simplified and does not require # using a custom MathSAT version. # (k is implicitly existentially quantified since we do not enumerate on it) k = self.new_cond_label() self.clauses.append( self.mgr.Or(*self.branch_condition, self.mgr.Not(k), phi) ) self.clauses.append( self.mgr.Or(*self.branch_condition, k, self.mgr.Not(phi)) ) self.branch_condition.append(self.mgr.Not(phi)) yield left # recursion on the left branch self.branch_condition.pop() self.branch_condition.append(phi) yield right # recursion on the right branch self.branch_condition.pop() else: b = self.new_cond_label() # Here we are not adding the clause # (branch_condition -> (b v not b)) # since it is subsumed by the CNF clauses of # (branch_condition -> exists b.CNF(b <-> phi)) # add (conds & b) -> CNF(phi) self.branch_condition.append(self.mgr.Not(b)) for clause in self.cnfizer.convert(phi): self.clauses.append(self.mgr.Or(*self.branch_condition, *clause)) yield left # recursion on the left branch self.branch_condition.pop() # add (conds & not b) -> CNF(not phi) self.branch_condition.append(b) for clause in self.cnfizer.convert(self.mgr.Not(phi)): self.clauses.append(self.mgr.Or(*self.branch_condition, *clause)) yield right # recursion on the right branch self.branch_condition.pop()
[docs] class CNFPreprocessor(IdentityDagWalker): """Converts nested ORs and ANDs into flat lists of ORs and ANDs, and Implies into Or.""" def __init__(self, env: Environment): super().__init__(env) self.nnfizer = NNFizer(env) def walk(self, formula: FNode, **kwargs: Any) -> FNode: formula = self.nnfizer.convert(formula) return super().walk(formula, **kwargs)
[docs] def walk_or(self, formula: FNode, args: list[FNode], **kwargs: Any) -> FNode: children = [] for arg in args: if arg.is_true(): return self.mgr.Bool(True) elif arg.is_false(): continue elif arg.is_or(): children.extend(arg.args()) elif arg.is_not() and arg.arg(0).is_and(): children.extend(map(self.mgr.Not, arg.arg(0).args())) else: children.append(arg) return self.mgr.Or(children)
[docs] def walk_and(self, formula: FNode, args: list[FNode], **kwargs: Any) -> FNode: children = [] for arg in args: if arg.is_false(): return self.mgr.Bool(False) elif arg.is_true(): continue elif arg.is_and(): children.extend(arg.args()) elif arg.is_not() and arg.arg(0).is_or(): children.extend(map(self.mgr.Not, arg.arg(0).args())) else: children.append(arg) return self.mgr.And(children)
[docs] class PolarityCNFizer(DagWalker): """Implements the Plaisted&Greenbaum CNF conversion algorithm.""" CNF = list[list[FNode]] def __init__(self, env: Environment): super().__init__(env) self.mgr = self.env.formula_manager self.preprocessor = CNFPreprocessor(env=self.env) self._introduced_variables: dict[FNode, FNode] = {} def _get_key(self, formula: FNode, **kwargs: Any) -> FNode: return formula def _key_var(self, formula: FNode) -> FNode: if formula in self._introduced_variables: res = self._introduced_variables[formula] else: res = self.mgr.FreshSymbol(typename=BOOL, template="CNFB%s") self._introduced_variables[formula] = res return res def _neg(self, formula: FNode) -> FNode: if formula.is_not(): return formula.arg(0) else: return self.mgr.Not(formula)
[docs] def convert(self, formula: FNode) -> frozenset[frozenset[FNode]]: """Converts formula into an equisatisfiable CNF. Returns a set of clauses, i.e. a set of sets of literals. """ def literals_in_clause(clause: FNode) -> Iterable[FNode]: if is_literal(clause): yield clause else: yield from clause.args() def literals_in_cnf(cnf: FNode) -> Iterable[Iterable[FNode]]: if is_clause(cnf): yield literals_in_clause(cnf) else: yield from (literals_in_clause(clause) for clause in cnf.args()) formula = self.preprocessor.walk(formula) if is_cnf(formula): return frozenset(map(frozenset, literals_in_cnf(formula))) cnf: list[list[FNode]] = list() tl: FNode = self.walk(formula, cnf=cnf) res = [] for clause in cnf: if len(clause) == 0: return frozenset(frozenset()) simp: list[FNode] = [] for lit in clause: if lit is tl or lit.is_true(): # Prune clauses that are trivially TRUE # and clauses containing the top level label simp = [] break elif not lit.is_false() and lit is not self._neg(tl): # Prune FALSE literals simp.append(lit) if simp: res.append(frozenset(simp)) return frozenset(res)
[docs] def walk_not( self, formula: FNode, args: list[FNode], cnf: CNF, **kwargs: Any ) -> FNode: a = args[0] if a.is_true(): return self.mgr.Bool(False) elif a.is_false(): return self.mgr.Bool(True) else: return self._neg(a)
[docs] def walk_and( self, formula: FNode, args: list[FNode], cnf: CNF, **kwargs: Any ) -> FNode: if len(args) == 1: return args[0] k = self._key_var(formula) for a in args: cnf.append([a, self._neg(k)]) return k
[docs] def walk_or( self, formula: FNode, args: list[FNode], cnf: CNF, **kwargs: Any ) -> FNode: if len(args) == 1: return args[0] k = self._key_var(formula) cnf.append([self._neg(k)] + args) return k
[docs] def walk_iff( self, formula: FNode, args: list[FNode], cnf: CNF, **kwargs: Any ) -> FNode: left, right = args if left == right: return self.mgr.Bool(True) k = self._key_var(formula) cnf.append([self._neg(k), self._neg(left), right]) cnf.append([self._neg(k), left, self._neg(right)]) cnf.append([k, left, right]) cnf.append([k, self._neg(left), self._neg(right)]) return k
[docs] @handles(op.SYMBOL) @handles(op.CONSTANTS) @handles(op.RELATIONS) @handles(op.THEORY_OPERATORS) def walk_identity(self, formula: FNode, **kwargs: Any) -> FNode: return formula