Source code for demestats.fit.fit_icr

from itertools import combinations
from typing import Any, List, Mapping, Set, Tuple

import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
from jax import vmap
from loguru import logger
from scipy.optimize import LinearConstraint, minimize

from demestats.fit.util import (
    _dict_to_vec,
    _vec_to_dict,
    _vec_to_dict_jax,
    create_inequalities,
    make_whitening_from_hessian,
    pullback_objective,
)
from demestats.icr import ICRCurve
from demestats.loglik.icr_loglik import icr_loglik

logger.disable("demestats")


def _count_samples_by_population(ts, sample_ids, pop_names):
    """Return ``{population name: count}`` for the nodes in ``sample_ids``."""
    counts = {name: 0 for name in pop_names}
    for node_id in sample_ids:
        counts[ts.population(ts.node(node_id).population).metadata["name"]] += 1
    return counts


def _first_coalescence_times(ts, sample_ids, positions):
    """Return, per position, the smallest parent time among the sampled nodes."""
    subsample_ts = ts.simplify(sample_ids)
    # The sample set of the simplified tree sequence does not change between
    # positions, so look it up once.
    subsample_nodes = subsample_ts.samples()
    times = []
    for pos in positions:
        tree = subsample_ts.at(pos)
        times.append(min(tree.time(tree.parent(node)) for node in subsample_nodes))
    return times


def get_tree_from_positions_data_efficient(
    ts, num_samples=50, gap=150000, k=2, seed=5, option="random"
):
    """
    Extract coalescence-time data from trees sampled at regularly spaced
    genome positions, together with the corresponding population sampling
    configurations.

    This function will return an error if there's missing data (i.e. where a
    genomic position doesn't have a local tree).

    Parameters
    ----------
    ts : tskit.TreeSequence
        The input tree sequence from which local trees are sampled.
    num_samples : int, optional
        The number of sampling replicates to generate when
        ``option="random"``. Default is ``50``.
    gap : int or float, optional
        The spacing between successive genomic positions at which trees are
        queried. Default is ``150000``.
    k : int, optional
        The number of sampled nodes. Default is ``2``.
    seed : int, optional
        The random seed used to initialize JAX-based sampling. Default is
        ``5``.
    option : {"random", "all"}, optional
        Strategy used to generate samples. With ``"random"``, ``num_samples``
        random samples of size ``k`` are drawn without replacement. With
        ``"all"``, every possible combination of ``k`` samples is used. If
        the total number of haploids (N) in the tree is large, do not use
        ``"all"`` as the number of sampling configurations is N choose k.

    Returns
    -------
    tuple
        A pair ``(data, cfg_list)`` where ``data`` is a JAX array with one
        row per sample set and one column per queried position, holding the
        extracted coalescence times, and ``cfg_list`` is a list of
        dictionaries giving the per-population sample counts of each row.

    Raises
    ------
    ValueError
        If ``option`` is neither ``"random"`` nor ``"all"``.

    Notes
    -----
    The queried positions are constructed by drawing a random starting
    position in the first interval of length ``gap`` and then stepping across
    the genome in increments of ``gap``. A position that lands at or beyond
    ``ts.sequence_length`` is dropped, since no tree covers it. For each
    queried position and each sample set, the function records the time to
    first coalescence across the sampled nodes.

    Sample indices ``0 .. ts.num_samples - 1`` are passed to ``ts.simplify``
    and ``ts.node`` directly as node ids.

    The keys of every configuration dictionary are in sorted deme-name order,
    so the layout is reproducible across interpreter runs.

    ::

        data, cfgs = get_tree_from_positions_data_efficient(
            ts, num_samples=100, gap=100000, k=2, seed=42, option="random",
        )
        data, cfgs = get_tree_from_positions_data_efficient(
            ts, gap=50000, k=2, option="all",
        )
    """
    # Validate up front: an unrecognised option previously fell through both
    # branches and silently produced an empty result.
    if option not in ("random", "all"):
        raise ValueError(f"option must be 'random' or 'all', got {option!r}")

    key = jax.random.PRNGKey(seed)
    num_trees = jnp.floor(ts.sequence_length / gap)
    start_position = jax.random.randint(key, (1,), 1, gap + 1)
    position_index = start_position + jnp.arange(num_trees) * gap
    key, subkey = jr.split(key)

    # Convert once to plain floats instead of pulling a device scalar per
    # lookup, and skip a position equal to the sequence length (possible when
    # the start equals ``gap`` and the length is a multiple of ``gap``).
    positions = [
        float(pos)
        for pos in np.asarray(position_index)
        if pos < ts.sequence_length
    ]

    pop_names = sorted(
        {ts.population(ts.node(n).population).metadata["name"] for n in ts.samples()}
    )

    if option == "random":
        sample_sets = []
        for _ in range(num_samples):
            drawn = jax.random.choice(
                subkey, ts.num_samples, shape=(k,), replace=False
            )
            key, subkey = jr.split(key)
            sample_sets.append([int(s) for s in np.asarray(drawn)])
    else:
        # Lazily enumerate all N-choose-k subsets rather than storing them.
        sample_sets = (list(c) for c in combinations(range(ts.num_samples), k))

    data_list = []
    cfg_list = []
    for sample_ids in sample_sets:
        cfg_list.append(_count_samples_by_population(ts, sample_ids, pop_names))
        data_list.append(_first_coalescence_times(ts, sample_ids, positions))

    return jnp.array(data_list), cfg_list
def process_data(cfg_list):
    """
    Convert a list of dictionary sampling configurations into a vectorized
    form for compatibility with JAX.

    Parameters
    ----------
    cfg_list : sequence of dict
        A non-empty sequence of dictionaries where each dictionary maps deme
        names to the number of sampled haploids in that configuration.

    Returns
    -------
    tuple
        A pair ``(cfg_mat, deme_names)`` where ``cfg_mat`` is a JAX ``int32``
        array of shape ``(num_samples, D)`` containing the sampling counts
        for each configuration and deme, and ``deme_names`` is the keys view
        of the first configuration, giving the deme name of each column.

    Raises
    ------
    IndexError
        If ``cfg_list`` is empty.

    Notes
    -----
    The deme ordering is taken from the keys of the first configuration in
    ``cfg_list`` and is used consistently for every row in the output matrix.
    If a deme is missing from a later configuration, its count is filled with
    ``0``.

    ::

        cfg_mat, deme_names = process_data([
            {"P0": 2, "P1": 0},
            {"P0": 1, "P1": 1},
            {"P0": 0, "P1": 2},
        ])
    """
    num_samples = len(cfg_list)
    deme_names = cfg_list[0].keys()
    D = len(deme_names)

    # Assemble all rows in Python and build the array once. JAX arrays are
    # immutable, so a per-cell ``.at[i, j].set`` produces a new array for
    # every entry.
    rows = [[cfg.get(name, 0) for name in deme_names] for cfg in cfg_list]
    cfg_mat = jnp.array(rows, dtype=jnp.int32).reshape(num_samples, D)
    return cfg_mat, deme_names
Path = Tuple[Any, ...] Var = Path | Set[Path] Params = Mapping[Var, float] def _compute_icr_likelihood(vec, args_nonstatic, args_static): path_order, data, cfg_mat = args_nonstatic icr_call, deme_names = args_static params = _vec_to_dict_jax(vec, path_order) jax.debug.print("param: {vec}", vec=vec) batched_loglik = vmap(icr_loglik, in_axes=(0, 0, None, None, None))( data, cfg_mat, params, icr_call, deme_names ) # jax.debug.print("batched_loglik: {}", batched_loglik) loss = -jnp.sum(batched_loglik) jax.debug.print("Loss: {loss}", loss=loss) return loss def neg_loglik(vec, g, preconditioner_nonstatic, args_nonstatic, lb, ub): if jnp.any(vec >= ub): return jnp.inf, jnp.full_like(vec, 1e10) if jnp.any(vec <= lb): return jnp.inf, jnp.full_like(vec, -1e10) return g(vec, preconditioner_nonstatic, args_nonstatic)
def fit(
    demo,
    paths: Params,
    data,
    cfg_list,
    cons,
    lb,
    ub,
    k,
    *,
    method: str = "trust-constr",
    gtol: float = 1e-8,
    xtol: float = 1e-8,
    maxiter: int = 1000,
    barrier_tol: float = 1e-8,
):
    """
    Fit demographic model parameters using ICR likelihood optimization.

    Parameters
    ----------
    demo : demes.Graph
        ``demes`` model graph.
    paths : Params
        Parameter paths to optimize, mapped to their starting values. The
        iteration order of ``paths`` fixes the order of the parameter vector.
    data : array_like
        Array containing the extracted coalescence times for each subsample
        across queried genomic positions.
    cfg_list : list of dict
        Per-population sample counts associated with each row of ``data``.
    cons : dict
        Dictionary containing equality and inequality constraints. Expected
        keys: ``'eq'`` for ``(Aeq, beq)`` equality constraints
        ``Aeq @ x = beq``, and ``'ineq'`` for ``(G, h)`` inequality
        constraints ``G @ x <= h``. A constraint whose matrix has ``size == 0``
        is skipped.
    lb : array_like
        Lower bounds for parameters.
    ub : array_like
        Upper bounds for parameters.
    k : int
        Sample size.
    method : str, optional
        Optimization method passed to ``scipy.optimize.minimize``
        (default: ``"trust-constr"``).
    gtol : float, optional
        Gradient tolerance for convergence (default: 1e-8).
    xtol : float, optional
        Parameter tolerance for convergence (default: 1e-8).
    maxiter : int, optional
        Maximum number of iterations (default: 1000).
    barrier_tol : float, optional
        Barrier tolerance for interior-point methods (default: 1e-8).

    Returns
    -------
    tuple
        ``(params_opt, opt_value, x_opt)`` where ``params_opt`` is the
        dictionary of optimized parameters, ``opt_value`` is the objective
        value reported by the optimizer (``res.fun``), and ``x_opt`` is the
        optimized parameter vector in the original coordinates.

    Notes
    -----
    The optimizer works on a shifted and linearly transformed variable ``y``
    with ``x = x0 + LinvT @ y``, starting from ``y = 0``. ``L`` and ``LinvT``
    come from ``make_whitening_from_hessian`` evaluated at the initial point
    (presumably a factorization of the objective's Hessian there; confirm
    against that helper).

    The linear constraints are rewritten in terms of ``y`` before being
    handed to the optimizer. Bounds are not passed to the optimizer; they are
    enforced inside ``neg_loglik``, which returns an infinite objective and a
    large constant gradient once a component reaches a bound.

    The optimized vector and the raw optimizer result are printed to stdout.
    """
    path_order: List[Var] = list(paths)
    x0 = jnp.array(_dict_to_vec(paths, path_order))
    data = jnp.array(data)
    lb = jnp.array(lb)
    ub = jnp.array(ub)

    cfg_mat, deme_names = process_data(cfg_list)
    icr = ICRCurve(demo=demo, k=k)
    icr_call = jax.jit(icr.__call__)

    args_nonstatic = (path_order, data, cfg_mat)
    args_static = (icr_call, deme_names)

    L, LinvT = make_whitening_from_hessian(
        _compute_icr_likelihood, x0, args_nonstatic, args_static
    )
    preconditioner_nonstatic = (x0, LinvT)
    g = pullback_objective(_compute_icr_likelihood, args_static)

    # The transformed problem starts at the origin (x = x0 + LinvT @ 0).
    y0 = np.zeros_like(x0)
    # NOTE(review): the bounds are mapped with L.T and later compared
    # elementwise against y in neg_loglik; unless L is diagonal this is not
    # the same region as lb <= x <= ub in the original coordinates. Confirm
    # this is intended.
    lb_tr = L.T @ (lb - x0)
    ub_tr = L.T @ (ub - x0)

    linear_constraints: list[LinearConstraint] = []

    Aeq, beq = cons["eq"]
    if Aeq.size:
        # Transform only when equality constraints exist, so an empty
        # placeholder of any shape never reaches the matrix product.
        A_tilde = Aeq @ LinvT
        b_tilde = beq - Aeq @ x0
        linear_constraints.append(LinearConstraint(A_tilde, b_tilde, b_tilde))

    G, h = cons["ineq"]
    if G.size:
        linear_constraints.append(create_inequalities(G, h, LinvT, x0))

    res = minimize(
        fun=neg_loglik,
        x0=y0,
        jac=True,  # neg_loglik returns (value, gradient)
        args=(g, preconditioner_nonstatic, args_nonstatic, lb_tr, ub_tr),
        method=method,
        constraints=linear_constraints,
        options={
            "gtol": gtol,
            "xtol": xtol,
            "maxiter": maxiter,
            "barrier_tol": barrier_tol,
        },
    )

    # Map the optimum back to the original parameter coordinates.
    x_opt = np.array(x0) + LinvT @ res.x

    print("optimal value: ")
    print(x_opt)
    print(res)

    return _vec_to_dict(jnp.asarray(x_opt), path_order), res.fun, x_opt