import copy
from typing import Any, Dict, List, Mapping, Sequence, Set, Tuple
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import LinearConstraint
import sgkit as sg
from demestats.constr import EventTree, constraints_for
# A parameter path: a tuple of keys/indices into a demographic model, e.g.
# ('demes', 0, 'epochs', 0, 'end_size') as in the docstring examples of this module.
Path = Tuple[Any, ...]
# A variable is either a single path or a set of paths (presumably several paths
# tied together as one free parameter -- TODO confirm against EventTree.variable_for).
Var = Path | Set[Path]
# Mapping from variables to their numeric values.
Params = Mapping[Var, float]
def _dict_to_vec(d: Params, keys: Sequence[Var]) -> jnp.ndarray:
"""
Convert dictionary of parameters to vector representation.
Parameters
----------
d : Params
Dictionary mapping parameter paths to values.
keys : Sequence[Var]
Ordered list of parameter paths.
Returns
-------
jax.Array
Vector representation of parameters in the order specified by keys.
Notes
-----
This utility function is used internally to convert between dictionary
and vector representations of parameters for optimization algorithms.
"""
return jnp.asarray([d[k] for k in keys], dtype=jnp.float64)
def _vec_to_dict_jax(v: jnp.ndarray, keys: Sequence[Var]) -> Dict[Var, jnp.ndarray]:
"""
Convert vector to dictionary with JAX arrays.
Parameters
----------
v : jax.Array
Vector of parameter values.
keys : Sequence[Var]
Ordered list of parameter paths.
Returns
-------
Dict[Var, jax.Array]
Dictionary mapping parameter paths to JAX array values.
"""
return {k: v[i] for i, k in enumerate(keys)}
def _vec_to_dict(v: jnp.ndarray, keys: Sequence[Var]) -> Dict[Var, float]:
"""
Convert vector to dictionary with Python floats.
Parameters
----------
v : jax.Array
Vector of parameter values.
keys : Sequence[Var]
Ordered list of parameter paths.
Returns
-------
Dict[Var, float]
Dictionary mapping parameter paths to Python float values.
"""
return {k: float(v[i]) for i, k in enumerate(keys)}
def fold_sfs(data):
    """
    Fold a (joint) site frequency spectrum.

    With ``data.shape == (n_1 + 1, ..., n_k + 1)``, entry
    ``data[i_1, ..., i_k]`` is added to its complement
    ``data[n_1 - i_1, ..., n_k - i_k]``. The folded values occupy the first
    half of the array in flattened (C) order. The second half is set to zero,
    so the output keeps the input's shape. When the total number of entries
    is odd, the central entry is its own complement and is kept unchanged.

    Parameters
    ----------
    data : array_like
        Spectrum of any shape (NumPy or JAX array).

    Returns
    -------
    jax.Array
        Folded spectrum with the same shape as ``data``.
    """
    arr = jnp.asarray(data)
    flat = arr.ravel()
    n = flat.shape[0]
    half = n // 2
    # Reversing a C-ordered flattened array flips every axis at once, so
    # flat[i] and flat[n - 1 - i] are complementary entries. This holds for
    # both parities: the first `half` reversed entries are the partners.
    folded = flat[:half] + flat[::-1][:half]
    if n % 2:
        # Use a slice (an array), not a Python list: jnp.concatenate rejects
        # list inputs, which made every odd-sized spectrum crash before.
        folded = jnp.concatenate([folded, flat[half:half + 1]])
    # Zero-fill the discarded half so the result can be reshaped back.
    folded = jnp.pad(folded, (0, n - folded.shape[0]), constant_values=0)
    return folded.reshape(arr.shape)
def joint_sfs_from_vcz(
    path_or_ds,
    populations,
    ploidy=1,
    fold=True,
):
    """
    Compute a joint site frequency spectrum (SFS) from a bio2zarr VCZ dataset.

    Parameters
    ----------
    path_or_ds : str or xarray.Dataset
        Path to a .vcz store, or an already-loaded sgkit dataset.
    populations : list[list[int]]
        Each element is a list of sample indices belonging to one population.
        Example: [[0,1,2,3,4], [5,6,7,8,9]]
    ploidy : int
        Number of allele copies contributed by each sample: 1 for haploids,
        2 for diploids. The sample size of a population is
        ``len(indices) * ploidy``.
    fold : bool
        If True, return the folded spectrum (computed with ``fold_sfs``).
        If False, return the unfolded (polarised) spectrum.

    Returns
    -------
    jsfs : np.ndarray or jax.Array
        Array of shape ``(n_1 + 1, ..., n_k + 1)``, where ``n_i`` is the
        sample size of population ``i``. It is an int64 NumPy array when
        unfolded; when folded it is the array returned by ``fold_sfs``.

    Notes
    -----
    Each of the ALT alleles 1, 2 and 3 at a site contributes its own entry
    (higher allele indices are ignored). Entries where every population has
    count 0, or every population is fixed, are set to zero.
    """
    ds = sg.load_dataset(path_or_ds) if isinstance(path_or_ds, str) else path_or_ds
    # Per-sample allele counts: (variants, samples, alleles)
    ds = sg.count_call_alleles(ds)
    call_allele_count = ds["call_allele_count"].compute().values.astype(np.int64)
    # Allele 0 is REF; count ALT alleles 1..3 that actually exist in the data.
    alt_alleles = range(1, min(call_allele_count.shape[2], 4))

    pop_ac = []
    sample_sizes = []
    for sample_idx in populations:
        sample_idx = np.asarray(sample_idx, dtype=int)
        # Sum within population: (variants, alleles)
        counts = call_allele_count[:, sample_idx, :].sum(axis=1)
        # One block of entries per ALT allele, stacked in the same order for
        # every population so the coordinates line up across populations.
        per_allele = [counts[:, allele] for allele in alt_alleles]
        if per_allele:
            pop_ac.append(np.concatenate(per_allele).astype(np.int64))
        else:
            pop_ac.append(np.zeros(0, dtype=np.int64))
        sample_sizes.append(len(sample_idx) * ploidy)

    jsfs = np.zeros(tuple(n + 1 for n in sample_sizes), dtype=np.int64)
    np.add.at(jsfs, tuple(pop_ac), 1)
    # Sites monomorphic across all samples carry no frequency information.
    jsfs[(0,) * len(sample_sizes)] = 0
    jsfs[tuple(sample_sizes)] = 0
    # Previously inverted (`fold == False`), so fold=True returned an unfolded SFS.
    if fold:
        jsfs = fold_sfs(jsfs)
    return jsfs
def joint_sfs_from_haploids(
    g,
    haploid_populations,
    polarised=True,
):
    """
    Compute a joint SFS from a haplotype matrix using explicit haplotype indices.

    Parameters
    ----------
    g : np.ndarray
        Haplotype matrix indexed as ``g[site, haplotype]``. Entries equal to
        1, 2 or 3 are counted as ALT alleles. Any other value (0 for REF, or
        e.g. a negative missing-data code) is not counted.
    haploid_populations : list[list[int]]
        Each element is a list of haplotype (column) indices into ``g``
        belonging to one population.
        Example: [[0,1,4,7], [2,3,5,6]]
    polarised : bool, default True
        If True, return the unfolded spectrum. If False, fold it with
        ``fold_sfs``.

    Returns
    -------
    jsfs : np.ndarray or jax.Array
        Array of shape ``(n_1 + 1, ..., n_k + 1)`` with ``n_i`` the number of
        haplotypes in population ``i``. It is an int64 NumPy array when
        polarised; otherwise it is the array returned by ``fold_sfs``.
        Each ALT allele (1, 2, 3) at a site contributes its own entry. The
        all-zero and all-fixed entries are set to zero.
    """
    sample_sizes = [len(pop) for pop in haploid_populations]
    jsfs = np.zeros(tuple(n + 1 for n in sample_sizes), dtype=np.int64)
    pop_ac = []
    for pop in haploid_populations:
        hap = g[:, np.asarray(pop, dtype=int)]
        # One block of rows per ALT allele, stacked in the same order for
        # every population so coordinates line up across populations.
        per_allele = [(hap == allele).sum(axis=1) for allele in range(1, 4)]
        pop_ac.append(np.concatenate(per_allele).astype(np.int64))
    np.add.at(jsfs, tuple(pop_ac), 1)
    # Sites monomorphic across all haplotypes carry no frequency information.
    jsfs[(0,) * len(sample_sizes)] = 0
    jsfs[tuple(sample_sizes)] = 0
    if not polarised:
        jsfs = fold_sfs(jsfs)
    return jsfs
def create_constraints(demo, paths):
    """
    Build the linear constraints for selected parameters of a demes graph.

    Parameters
    ----------
    demo : demes.Graph
        Demographic model.
    paths : Iterable[Var]
        Parameter paths to include in the constraints. The iterable is
        consumed once, in order.

    Returns
    -------
    dict
        Constraint representation returned by ``constraints_for``: equality
        constraints ``A @ x = b`` and inequality constraints ``A @ x <= b``
        (accessed elsewhere in this module as ``["eq"]`` and ``["ineq"]``).

    Notes
    -----
    An ``EventTree`` is built from ``demo`` and passed, together with the
    requested paths, to ``constraints_for``. The constraints encode
    demographic consistency (e.g. positive population sizes, chronological
    ordering of events).

    See Also
    --------
    demestats.constr.EventTree
    demestats.constr.constraints_for
    """
    requested = [*paths]
    tree = EventTree(demo)
    return constraints_for(tree, *requested)
def finite_difference_hessian(f, x0, args_nonstatic, args_static, eps=1e-6):
    """
    Approximate the Hessian diagonal of ``f`` by differencing its gradient.

    Parameters
    ----------
    f : callable
        Objective ``f(x, args_nonstatic, args_static)`` returning a scalar.
    x0 : jax.Array
        1-D point at which the curvature is estimated. ``.at[]`` updates are
        applied to it, so a JAX array is expected.
    args_nonstatic : tuple
        Passed to ``f`` as its second argument, through the jitted gradient.
    args_static : tuple
        Passed to ``f`` as its third argument. It is captured by closure
        rather than passed through the jitted call.
    eps : float, optional
        Finite difference step size (default: 1e-6).

    Returns
    -------
    jax.Array
        ``(n, n)`` diagonal matrix with
        ``H[i, i] = (g_i(x0 + eps*e_i) - g_i(x0 - eps*e_i)) / (2*eps)``,
        where ``g = grad f``. Off-diagonal entries are zero.

    Notes
    -----
    Costs ``2 n`` gradient evaluations. This is cheaper than the full Hessian
    and is often sufficient for preconditioning.

    See Also
    --------
    demestats.fit.util.make_whitening_from_hessian
    """
    dim = len(x0)

    def objective(params, nonstatic):
        return f(params, nonstatic, args_static)

    gradient = eqx.filter_jit(jax.grad(objective))
    curvature = jnp.zeros(dim)
    for idx in range(dim):
        # Central difference of the idx-th gradient component along axis idx.
        forward = gradient(x0.at[idx].add(eps), args_nonstatic)[idx]
        backward = gradient(x0.at[idx].add(-eps), args_nonstatic)[idx]
        curvature = curvature.at[idx].set((forward - backward) / (2 * eps))
    return jnp.diag(curvature)
def make_whitening_from_hessian(f, x0, args_nonstatic, args_static, tau=1e-3, lam=1e-3, eps=1e-6):
    """
    Create a whitening transformation (preconditioner) from a Hessian approximation.

    Parameters
    ----------
    f : callable
        Objective function ``f(x, args_nonstatic, args_static)``.
    x0 : jax.Array
        Reference point for the Hessian evaluation.
    args_nonstatic : tuple
        Non-static arguments to f.
    args_static : tuple
        Static arguments to f.
    tau : float, optional
        Floor applied to the absolute eigenvalues (default: 1e-3).
    lam : float, optional
        Regularization added to the eigenvalues after flooring (default: 1e-3).
    eps : float, optional
        Finite difference step forwarded to ``finite_difference_hessian``
        (default: 1e-6, the same value as before this parameter existed).

    Returns
    -------
    tuple
        ``(L, LinvT)``: ``L`` is the whitening matrix and ``LinvT`` is the
        transpose of its inverse, used to map whitened coordinates back via
        ``x = x0 + LinvT @ y`` and to transform constraints.

    Notes
    -----
    With ``H = V @ diag(lambda) @ V.T`` the eigendecomposition of the
    symmetrized Hessian estimate, ``L = V @ diag(sqrt(lambda')) @ V.T``, where
    ``lambda' = max(|lambda|, tau) + lam`` is positive. The estimate comes
    from ``finite_difference_hessian``, which returns a diagonal matrix.
    Whitening makes the objective's contours closer to spherical, which helps
    the optimizer.
    """
    H = finite_difference_hessian(f, x0, args_nonstatic, args_static, eps=eps)
    # Symmetrize to remove any round-off asymmetry before eigh.
    H = 0.5 * (H + H.T)
    evals, evecs = jnp.linalg.eigh(H)
    # Make the spectrum strictly positive so sqrt and the inverse are well defined.
    evals = jnp.maximum(jnp.abs(evals), tau) + lam
    L = evecs @ jnp.diag(jnp.sqrt(evals)) @ evecs.T
    LinvT = jnp.linalg.solve(L, jnp.eye(L.shape[0])).T
    return L, LinvT
def pullback_objective(f, args_static):
    """
    Express an objective in whitened coordinates and pair it with its gradient.

    Parameters
    ----------
    f : callable
        Original objective ``f(x, args_nonstatic, args_static)``.
    args_static : tuple
        Static arguments to f, captured by closure.

    Returns
    -------
    callable
        JIT-compiled ``g(y, preconditioner_nonstatic, args_nonstatic)``
        returning ``(value, gradient)``, with the gradient taken with
        respect to ``y``. ``preconditioner_nonstatic`` is the pair
        ``(x0, LinvT)``.

    Notes
    -----
    Whitened parameters ``y`` are mapped back to the original space as
    ``x = x0 + LinvT @ y`` before ``f`` is evaluated.

    See Also
    --------
    demestats.fit.util.make_whitening_from_hessian
    """
    def whitened(y, preconditioner_nonstatic, args_nonstatic):
        anchor, inv_transpose = preconditioner_nonstatic
        return f(anchor + inv_transpose @ y, args_nonstatic, args_static)

    return eqx.filter_jit(eqx.filter_value_and_grad(whitened))
def apply_jit(f, args_nonstatic, args_static):
    """
    Bind the extra arguments of ``f`` and JIT-compile it as a function of ``x``.

    Parameters
    ----------
    f : callable
        Function ``f(x, args_nonstatic, args_static)``.
    args_nonstatic : tuple
        Second argument of ``f``, captured by closure.
    args_static : tuple
        Third argument of ``f``, captured by closure.

    Returns
    -------
    callable
        ``eqx.filter_jit``-compiled ``h(x)`` that promotes ``x`` to at least
        1-D (so scalars are accepted) and returns
        ``f(x, args_nonstatic, args_static)``.
    """
    def evaluate(x):
        return f(jnp.atleast_1d(x), args_nonstatic, args_static)

    return eqx.filter_jit(evaluate)
[docs]
def alternative_constraint_rep(A, b):
"""
Returns an alternative representation of inequality constraints with a lower and upper bound.
Depending on the numerical optimizer one would like to use, sometimes
it's more preferable to express inequality constraints explicitly with a
lower and upper bound. The input for the function are inequality constraints of the form
Ax <= b which is exactly what the ``demestats.constr.constraints_for`` function returns.
Parameters
----------
A : array_like
Coefficients for inequalities of the form Ax <= b
b : array_list
Values for inequalities of the form Ax <= b
Notes
-----
Example:
::
# See tutorial for a detailed example
parameters = [
('demes', 0, 'epochs', 0, 'end_size'), # The ancestral population size
('migrations', 0, 'rate'), # Rate of migration from P0 to P1
('demes', 0, 'epochs', 0, 'end_time') # Time of divergence
]
momi3_parameters = [et.variable_for(param) for param in parameters]
constraint = constraints_for(et, *momi3_parameters)
G, h = constraint["ineq"]
A_alt, ub_alt, lb_alt = alternative_constraint_rep(G, h)
print(A_alt)
print("lower bound: ", lb_alt)
print("upper bound: ", ub_alt)
Please refer to the tutorial for a specific example, the above provided codes are just outlines of how to call on the functions.
See Also
--------
demestats.constr.constraint_for
"""
replace_idx = np.ones(len(b), dtype=bool)
A_combined = []
lb_combined = []
ub_combined = []
# only conditions with one SINGLE index that's repeated will be joined together. If a row has two non-zero indices, that must be copied exactly
for i in range(len(A)):
a_row1, b_val1 = A[i], b[i]
idx1 = np.where(a_row1 != 0)[0]
for j in range(i + 1, len(A)): # Start from i+1
a_row2, b_val2 = A[j], b[j]
idx2 = np.where(a_row2 != 0)[0]
if (len(idx1) == 1) and np.array_equal(idx1, idx2) and replace_idx[i]:
replace_idx[i] = False
replace_idx[j] = False
if a_row1[idx1[0]] == -1:
A_combined.append(a_row2)
ub_combined.append(b_val2)
lb_combined.append(b_val1)
else:
A_combined.append(a_row1)
ub_combined.append(b_val1)
lb_combined.append(b_val2)
if replace_idx[i]:
if len(idx1) == 1:
a_row1 = -1 * a_row1
A_combined.append(a_row1)
ub_combined.append(np.inf)
lb_combined.append(b_val1)
else:
A_combined.append(a_row1)
ub_combined.append(b_val1)
lb_combined.append(-np.inf)
A_combined = jnp.array(A_combined)
ub_combined = jnp.array(ub_combined)
lb_combined = jnp.array(lb_combined)
return A_combined, ub_combined, lb_combined
def create_inequalities(A, b, LinvT, x0):
    """
    Express ``A x <= b`` as a ``scipy.optimize.LinearConstraint`` in whitened coordinates.

    Parameters
    ----------
    A : jax.Array or numpy.ndarray
        Constraint matrix of shape (m, n).
    b : jax.Array or numpy.ndarray
        Constraint bounds of shape (m,).
    LinvT : jax.Array
        Whitening transformation matrix (inverse transpose of the preconditioner).
    x0 : jax.Array
        Reference point in the original parameter space.

    Returns
    -------
    scipy.optimize.LinearConstraint
        Constraint ``lb_tilde <= A_tilde @ y <= ub_tilde``, where
        ``y`` is related to ``x`` by ``x = x0 + LinvT @ y``.

    Notes
    -----
    ``A x <= b`` is first rewritten in lower/upper bound form
    ``lb <= A_alt x <= ub`` by ``alternative_constraint_rep``. Substituting
    ``x = x0 + LinvT @ y`` (the change of variables used by
    ``pullback_objective``) gives ``A_tilde = A_alt @ LinvT`` and bounds
    shifted by ``A_alt @ x0``.

    See Also
    --------
    demestats.fit.util.alternative_constraint_rep
    demestats.fit.util.make_whitening_from_hessian
    """
    A_combined, ub_combined, lb_combined = alternative_constraint_rep(A, b)
    A_tilde = A_combined @ LinvT
    shift = A_combined @ x0
    return LinearConstraint(A_tilde, lb_combined - shift, ub_combined - shift)
def modify_constraints_for_equality(constraint, indices_for_equality):
    """
    Return a copy of ``constraint`` with extra equality constraints tying parameters together.

    For every pair ``(i, j)`` in ``indices_for_equality``, a row enforcing
    ``x_i - x_j = 0`` is appended to the equality system ``constraint["eq"]``.
    The input dictionary is deep-copied and left unmodified.

    Parameters
    ----------
    constraint : dict
        A dictionary of equality and inequality constraints. Its ``"eq"``
        entry is ``(A_eq, b_eq)``, describing ``A_eq @ x = b_eq``.
    indices_for_equality : Iterable[tuple[int, int]]
        Pairs of parameter indices (positions in ``x``) that must be equal.
        Pairs with ``i == j`` are trivially satisfied and are skipped.

    Returns
    -------
    dict
        A modified copy of ``constraint`` with the new equality constraints.

    Notes
    -----
    Example:
    ::
        # See tutorial for a detailed example
        parameters = [
            ('demes', 0, 'epochs', 0, 'end_size'), # The ancestral population size
            ('migrations', 0, 'rate'), # Rate of migration from P0 to P1
            ('demes', 0, 'epochs', 0, 'end_time') # Time of divergence
        ]
        momi3_parameters = [et.variable_for(param) for param in parameters]
        constraint = constraints_for(et, *momi3_parameters)
        # new_constraint will have the 2nd and 3rd variable be constrained to be equal
        new_constraint = modify_constraints_for_equality(constraint, [(1, 2)])
        print(new_constraint)

    Please refer to the tutorial for a specific example; the code above only
    outlines how to call the functions.

    See Also
    --------
    demestats.constr.constraints_for
    """
    constraint_copy = copy.deepcopy(constraint)
    A_eq, b_eq = constraint_copy["eq"]
    n_vars = A_eq.shape[1]

    new_rows = []
    for index1, index2 in indices_for_equality:
        if index1 == index2:
            # x_i == x_i always holds. Building the row would overwrite the +1
            # with -1 and wrongly force x_i = 0.
            continue
        rule = np.zeros(n_vars)
        rule[index1] = 1.0
        rule[index2] = -1.0
        new_rows.append(rule)

    if new_rows:
        # Stack all new rows at once instead of re-stacking once per pair.
        A_eq = np.vstack([A_eq, *new_rows])
        b_eq = np.concatenate([b_eq, np.zeros(len(new_rows))])
    constraint_copy["eq"] = (A_eq, b_eq)
    return constraint_copy
def process_data(cfg_list):
    """
    Convert per-deme sample configurations into a matrix and deduplicate its rows.

    Parameters
    ----------
    cfg_list : list[dict]
        Each element maps a deme name to an integer (presumably a sample
        count for that deme -- confirm against callers). Column names and
        their order come from the FIRST element. Keys that appear only in
        later elements are ignored, and missing keys count as 0.

    Returns
    -------
    cfg_mat : jax.Array, shape (len(cfg_list), D), int32
        ``cfg_mat[i, j] = cfg_list[i].get(deme_names[j], 0)``.
    deme_names : dict_keys
        ``cfg_list[0].keys()``, the column labels of ``cfg_mat``.
    unique_cfg : jax.Array, shape (U, D)
        Distinct rows of ``cfg_mat`` in lexicographically sorted order.
    matching_indices : jax.Array, shape (len(cfg_list),)
        ``unique_cfg[matching_indices[i]]`` equals ``cfg_mat[i]``.

    Raises
    ------
    IndexError
        If ``cfg_list`` is empty.
    """
    deme_names = cfg_list[0].keys()
    # Build on the host with NumPy. Filling a JAX array element by element
    # with `.at[].set` copies the whole matrix on every update.
    host_mat = np.array(
        [[cfg.get(name, 0) for name in deme_names] for cfg in cfg_list],
        dtype=np.int32,
    ).reshape(len(cfg_list), len(deme_names))
    # return_inverse gives each row's index into the unique rows in one pass,
    # replacing a per-row linear scan. Ravel for NumPy-version-independent shape.
    unique_rows, inverse = np.unique(host_mat, axis=0, return_inverse=True)
    return (
        jnp.asarray(host_mat),
        deme_names,
        jnp.asarray(unique_rows),
        jnp.asarray(inverse.reshape(-1)),
    )