Source code for demestats.constr

"Constraints from event tree"

from typing import Any

import numpy as np
from beartype.typing import Sequence, TypedDict
from jaxtyping import ArrayLike, Float
from scipy.optimize import linprog

from .event_tree import EventTree, Variable
from .path import Path


class ConstraintSet(TypedDict):
    var: Sequence[Variable]
    eq: tuple[Float[ArrayLike, "eq d"], Float[ArrayLike, "eq"]]
    ineq: tuple[Float[ArrayLike, "ineq d"], Float[ArrayLike, "ineq"]]


[docs] def constraints_for(et: EventTree, *vars_: Variable) -> ConstraintSet: """ Return a list of constraints for the given variables. Parameters ---------- et : EventTree The event tree to extract constraints from var : Variable The variables for which to extract constraints, which must exist in `et.variables`. Returns: ConstraintSet : A dictionary containing the equality and inequality constraints, as well as a mapping of columns of the constraint matrices to the variables. Notes ----- Example: :: et = EventTree(demo.to_demes()) variables = et.variables final_constraints = constraints_for(et, *variables) Please refer to the tutorial for a specific example, the above provided codes are just outlines of how to call on the functions. """ missing = [v for v in vars_ if v not in et.variables] if missing: raise ValueError( f"Variables {missing} not found in event tree. Use et.variables to see available variables." ) A = [] b = [] G = [] h = [] # we iterate through each (set of) variables. all_variables = list(et.variables) def var_index(p: Path) -> int: return next(j for j, v in enumerate(all_variables) if p == v or p in v) n = len(all_variables) I = np.eye(n) for i, v in enumerate(all_variables): vs = v if isinstance(v, tuple): vs = set([v]) v0 = next(iter(vs)) val = et.get_path(v0) if v not in vars_: # this is a fixed value, so constrain it to be equal to its # initial value if np.isfinite(val): # fixed value A.append(I[i]) b.append(val) continue # this is a variable we optimize over, so add the relevant constraints for path in vs: match path: case (*_, "time" | "start_time" | "end_time"): G.append(-I[i]) h.append(0.0) # find the time immediately above this in the event tree # and constrain time ordering # 1: less than parent node = next( node for node in et.nodes if et.nodes[node]["t"] == path ) parent = node while et.nodes[parent]["t"] in vs: (parent,) = et.T.successors(parent) parent_t = et.nodes[parent]["t"] if np.isfinite(et.get_time(parent)): parent_t_var = var_index(parent_t) G.append(I[i] - I[parent_t_var]) h.append(0.0) # 2: greater than children: traverse down until we find a child # with a different time q = list(et.T.predecessors(node)) while q: child = q.pop() child_t = et.nodes[child]["t"] if et.nodes[child]["t"] in vs: q.extend(et.T.predecessors(child)) continue child_t_var = var_index(child_t) G.append(-I[i] + I[child_t_var]) h.append(0.0) case (*_, "start_size" | "end_size"): # this is a size variable, so constrain it to be non-negative G.append(-I[i]) h.append(0.0) case (*_, "rate") | (*_, "proportions", _): # this is a rate variable or proportion, so constrain it to be in [0, 1] # per-generation migration rates are constrained to be in [0, 1] by spec G.append(-I[i]) h.append(0.0) G.append(I[i]) h.append(1.0) # end for # there is a linear constraint on all proportions adding up to 1 for i, deme in enumerate(et.demo.demes): if deme.proportions: paths = [ ("demes", i, "proportions", j) for j, _ in enumerate(deme.proportions) ] indices = map(var_index, paths) A.append(sum(I[j] for j in indices)) b.append(1.0) for i, pulse in enumerate(et.demo.pulses): paths = [ ("pulses", i, "proportions", j) for j, _ in enumerate(pulse.proportions) ] indices = map(var_index, paths) A.append(sum(I[j] for j in indices)) b.append(1.0) A, b, G, h = map(np.array, (A, b, G, h)) A, G = [x.reshape(-1, n) for x in (A, G)] i_r = list(map(all_variables.index, vars_)) # indices of variables i_f = [p for p in range(n) if p not in i_r] # fixed indices xf = [et.get_var(var) for var in all_variables if var not in vars_] A_r = A[:, i_r] b_r = b - A[:, i_f].dot(xf) G_r = G[:, i_r] h_r = h - G[:, i_f].dot(xf) A_r, b_r, G_r, h_r = _reduce(A_r, b_r, G_r, h_r) return dict( eq=(A_r, b_r), ineq=(G_r, h_r), )
def _reduce( A: Float[ArrayLike, "E P"], b: Float[ArrayLike, "E"], G: Float[ArrayLike, "I P"], h: Float[ArrayLike, "I"], ) -> tuple[ Float[ArrayLike, "Er P"], Float[ArrayLike, "Er"], Float[ArrayLike, "Ir P"], Float[ArrayLike, "Ir"], ]: """Reduce the system of equations Ax = b by removing redundant equalities.""" U, s, Vt = np.linalg.svd(A) rank = np.sum(s > 1e-10) A_red = U[:, :rank].T @ A b_red = U[:, :rank].T @ b G_red = G h_red = h def f(): nonlocal G_red, h_red for i in range(len(h_red)): if _is_redundant_inequality(A_red, b_red, G_red, h_red, i): G_red = np.delete(G_red, i, axis=0) h_red = np.delete(h_red, i) return False return True while not f(): pass return A_red, b_red, G_red, h_red def _is_redundant_inequality(A, b, G, h, i, tol=1e-7): c = -G[i] # maximize g_i^T x => minimize -g_i^T x # Remove i-th inequality G_rest = np.delete(G, i, axis=0) h_rest = np.delete(h, i) res = linprog( c, A_eq=A, b_eq=b, A_ub=G_rest, b_ub=h_rest, bounds=(None, None), method="highs" ) if not res.success: return False # LP failed or infeasible, be conservative max_value = -res.fun return max_value <= h[i] + tol def display_constraint( A: np.ndarray, b: np.ndarray, x: list[Any], equality: bool ) -> None: """ Compact display of constraints without variable definitions. """ m, n = A.shape print("\nCONSTRAINTS:") print("-" * 50) for i in range(m): # Build expression expr_parts = [] for j in range(n): coeff = A[i, j] if coeff != 0: if coeff == 1: expr_parts.append(f"x{j + 1}") elif coeff == -1: expr_parts.append(f"-x{j + 1}") else: if coeff.is_integer(): coeff_str = str(int(coeff)) else: coeff_str = f"{coeff:.3f}".rstrip("0").rstrip(".") expr_parts.append(f"{coeff_str}·x{j + 1}") if not expr_parts: expr = "0" else: expr = " + ".join(expr_parts).replace("+ -", " - ") # Format RHS if b[i].is_integer(): rhs = str(int(b[i])) else: rhs = f"{b[i]:.3f}".rstrip("0").rstrip(".") if equality: print(f"Row {i + 1}: {expr} = {rhs}") else: print(f"Row {i + 1}: {expr} <= {rhs}") print("-" * 50) def display_constraint_strings( A: np.ndarray, b: np.ndarray, x: list[Any], equality: bool ) -> None: """ Return list of constraint strings. """ m, n = A.shape constraints = [] print("\nAS STRINGS:") print("-" * 50) for i in range(m): terms = [] for j in range(n): coeff = A[i, j] if coeff != 0: if coeff == 1: terms.append(f"{x[j]}") elif coeff == -1: terms.append(f"-{x[j]}") else: terms.append(f"{coeff} * {x[j]}") if not terms: lhs = "0" else: lhs = " + ".join(terms).replace("+ -", " - ") if equality: constraints.append(f"{lhs} = {b[i]}") else: constraints.append(f"{lhs} <= {b[i]}") for i, constr in enumerate(constraints): print(f"Row {i + 1}: {constr}") print("-" * 50)