from typing import Collection
import networkx as nx
from pysmt.fnode import FNode
from pysmt.typing import REAL, BOOL
from .polynomial import Polynomial, PolynomialParser
from .polytope import Polytope
from wmpy.enumeration.enumerator import Enumerator
[docs]
class AssignmentConverter:
"""This class is responsible of converting the pysmt assignments returned by an enumerator into pairs <Polytope, Polynomial>."""
def __init__(self, enumerator: "Enumerator", domain: Collection[FNode]) -> None:
"""Default constructor.
Args:
enumerator: the enumerator instance
domain: list of real variables in pysmt format
"""
self.enumerator = enumerator
self.poly_parser = PolynomialParser(domain, self.enumerator.env)
[docs]
def convert(
self, truth_assignment: dict[FNode, bool]
) -> tuple[Polytope, Polynomial]:
"""Converts a truth assignment (as returned by an Enumerator)
into a <Polytope, Polynomial> pair.
Args:
truth_assignment: mapping pysmt atoms to bool
Returns:
A convex integration problem as a pair of instances of Polytope and Polynomial.
The two represent the convex integration bounds and integrand respectively.
"""
mgr = self.enumerator.env.formula_manager
uncond_weight = self.enumerator.weights.weight_from_assignment(truth_assignment)
# build a dependency graph of the alias substitutions
# handle non-constant and constant definitions separately
Gsub: nx.DiGraph = nx.DiGraph()
constants = {}
aliases: dict[FNode, FNode] = {}
inequalities = []
for atom, truth_value in truth_assignment.items():
if atom.is_le() or atom.is_lt():
inequalities.append(atom if truth_value else mgr.Not(atom))
elif atom.is_equals() and truth_value:
left, right = atom.args()
if left.is_symbol(REAL):
alias, expr = left, right
elif right.is_symbol(REAL):
alias, expr = right, left
else:
raise ValueError(f"Malformed alias {atom}")
if alias in aliases:
msg = f"Multiple aliases {alias}:\n1) {expr}\n2) {aliases[alias]}"
raise ValueError(msg)
aliases[alias] = expr
for var in expr.get_free_variables():
Gsub.add_edge(alias, var)
if len(expr.get_free_variables()) == 0: # constant handled separately
constants.update({alias: expr})
elif atom.is_symbol(BOOL):
pass
else:
raise ValueError(f"Unsupported atom in assignment: {atom}")
# order of substitutions is determined by a topological sort of the digraph
try:
order = [node for node in nx.topological_sort(Gsub) if node in aliases]
except nx.NetworkXUnfeasible:
raise ValueError("Cyclic aliases definition")
convex_formula = mgr.And(inequalities)
for alias in order:
convex_formula = convex_formula.substitute({alias: aliases[alias]})
uncond_weight = uncond_weight.substitute({alias: aliases[alias]})
# substitute all constants
if len(constants) > 0:
uncond_weight = uncond_weight.substitute(constants)
convex_formula = convex_formula.substitute(constants)
inequalities = []
for literal in convex_formula.args():
if literal.is_not():
negated_atom = literal.args()[0]
left, right = negated_atom.args()
if negated_atom.is_le():
atom = mgr.LT(right, left)
elif negated_atom.is_lt():
atom = mgr.LE(right, left)
else:
raise NotImplementedError("Unhandled case")
else:
atom = literal
# Add a bound if the atom is an inequality
if atom.is_le() or atom.is_lt():
inequalities.append(atom)
else:
raise NotImplementedError("Unhandled case")
polytope = Polytope(inequalities, self.poly_parser)
polynomial = self.poly_parser.parse(uncond_weight)
return polytope, polynomial