from typing import Any, List, Mapping, Set, Tuple
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import LinearConstraint, minimize
from demestats.fit.util import (
_dict_to_vec,
_vec_to_dict,
_vec_to_dict_jax,
create_inequalities,
make_whitening_from_hessian,
pullback_objective,
)
from demestats.loglik.sfs_loglik import (
prepare_projection,
projection_sfs_loglik,
sfs_loglik,
)
from demestats.sfs import ExpectedSFS
Path = Tuple[Any, ...]
Var = Path | Set[Path]
Params = Mapping[Var, float]
def _compute_sfs_likelihood(vec, args_nonstatic, args_static):
"""
Compute negative log-likelihood for SFS parameters for a given
parameter vector. Supports both full SFS computation and projected
approximations.
Parameters
----------
vec : jax.Array
Parameter vector to evaluate.
args_nonstatic : tuple
Tuple containing non-static arguments:
(path_order, proj_dict, input_arrays, sequence_length, theta, projection, afs)
args_static : tuple
Tuple containing static (compile-time) arguments:
(esfs_obj, einsum_str)
Returns
-------
float
Negative log-likelihood value.
Notes
-----
This function is JIT-compiled and used as the core objective function
for optimization. When projection is enabled, it uses random projections
to reduce computational cost. Debug prints are included for parameter
and loss tracing during optimization.
See Also
--------
demestats.fit.fit_sfs.fit
"""
(path_order, proj_dict, input_arrays, sequence_length, theta, projection, afs, folded) = (
args_nonstatic
)
(esfs_obj, einsum_str) = args_static
params = _vec_to_dict_jax(vec, path_order)
# jax.debug.print("Params: {params}", params=vec)
if projection:
loss = -projection_sfs_loglik(
esfs_obj,
params,
proj_dict,
einsum_str,
input_arrays,
sequence_length,
theta,
folded,
)
# jax.debug.print("Loss: {loss}", loss=loss)
return loss
else:
esfs = esfs_obj(params)
loss = -sfs_loglik(afs, esfs, sequence_length, theta, folded)
# jax.debug.print("Loss full sfs: {loss}", loss=loss)
return loss
def neg_loglik(vec, g, preconditioner_nonstatic, args_nonstatic, lb, ub):
"""
Wrapper function that checks parameter boundaries before evaluating
the objective function. Returns infinite loss for parameters outside
bounds to enforce constraints.
Parameters
----------
vec : jax.Array
Parameter vector in transformed space.
g : callable
Pullback objective function to evaluate.
preconditioner_nonstatic : tuple
Preconditioner arguments (x0, LinvT) for transforming parameters.
args_nonstatic : tuple
Non-static arguments for the objective function.
lb : jax.Array
Lower bounds in transformed space.
ub : jax.Array
Upper bounds in transformed space.
Returns
-------
tuple
(loss, gradient) where loss is the negative log-likelihood or
infinity for invalid parameters, and gradient is a vector of
penalty gradients for boundary violations.
Notes
-----
This function serves as the interface between scipy.optimize and
JAX-compiled functions. It provides gradient information for
boundary violations to guide optimization back to feasible regions.
See Also
--------
demestats.fit.fit_sfs.fit
"""
if jnp.any(vec >= ub):
return jnp.inf, jnp.full_like(vec, 1e10)
if jnp.any(vec <= lb):
return jnp.inf, jnp.full_like(vec, -1e10)
return g(vec, preconditioner_nonstatic, args_nonstatic)
[docs]
def fit(
demo,
paths: Params,
afs,
afs_samples,
cons,
lb,
ub,
*,
folded = False,
method: str = "trust-constr",
sequence_length: float = None,
theta: float = None,
projection: bool = False,
num_projections: float = 200,
seed: float = 5,
gtol: float = 1e-5,
xtol: float = 1e-5, # default 1e-8
maxiter: int = 1000, # default 1000
barrier_tol: float = 1e-5,
):
"""
Fit demographic model parameters using SFS likelihood optimization.
Main optimization function that estimates demographic parameters by
maximizing the likelihood of the observed allele frequency spectrum.
Supports both full SFS computation and accelerated projected methods.
Parameters
----------
demo : demes.Graph
``demes`` model graph.
paths : Params
Parameter paths to optimize. Each path specifies a demographic
parameter in the model.
afs : array_like
Observed allele frequency spectrum.
afs_samples : dictionary
Dictionary specifying the number of haploids in each population for the afs
cons : dict
Dictionary containing equality and inequality constraints.
Expected keys: 'eq' for (Aeq, beq) equality constraints Aeq@x = beq,
and 'ineq' for (G, h) inequality constraints G@x <= h.
lb : array_like
Lower bounds for parameters.
ub : array_like
Upper bounds for parameters.
folded : bool, optional
Set equal to True if you have a folded SFS, otherwise the default is False.
method : str, optional
Optimization method (default: "trust-constr").
sequence_length : float, optional
Sequence length. Required for Poisson likelihood when theta is given.
theta : float, optional
Population-scaled mutation rate. If provided, Poisson likelihood
is used instead of multinomial.
projection : bool, optional
Whether to use random projections for acceleration (default: False).
num_projections : int, optional
Number of random projections to use if projection=True (default: 200).
seed : int, optional
Random seed for projection matrix generation (default: 5).
gtol : float, optional
Gradient tolerance for convergence (default: 1e-5).
xtol : float, optional
Parameter tolerance for convergence (default: 1e-5).
maxiter : int, optional
Maximum number of iterations (default: 1000).
barrier_tol : float, optional
Barrier tolerance for interior-point methods (default: 1e-5).
Returns
-------
tuple
(params_opt, opt_value, x_opt) where:
- params_opt: Dictionary of optimized parameters
- opt_value: Optimal negative log-likelihood value
- x_opt: Optimized parameter vector
Notes
-----
This function implements a sophisticated optimization pipeline:
1. Parameter space transformation using Hessian-based whitening
2. Constraint handling with equality and inequality constraints
3. Optional random projections for computational efficiency
4. Boundary enforcement with penalty gradients
The optimization is performed in a transformed space where the Hessian
is approximately identity, improving convergence rates.
"""
path_order: List[Var] = list(paths)
x0 = _dict_to_vec(paths, path_order)
x0 = jnp.array(x0)
lb = jnp.array(lb)
ub = jnp.array(ub)
afs = jnp.array(afs)
esfs_obj = ExpectedSFS(demo, num_samples=afs_samples)
if projection:
proj_dict, einsum_str, input_arrays = prepare_projection(
afs, afs_samples, sequence_length, num_projections, seed
)
else:
proj_dict, einsum_str, input_arrays = None, None, None
args_nonstatic = (
path_order,
proj_dict,
input_arrays,
sequence_length,
theta,
projection,
afs,
folded,
)
args_static = (esfs_obj, einsum_str)
L, LinvT = make_whitening_from_hessian(
_compute_sfs_likelihood, x0, args_nonstatic, args_static
)
preconditioner_nonstatic = (x0, LinvT)
g = pullback_objective(_compute_sfs_likelihood, args_static)
y0 = np.zeros_like(x0)
lb_tr = L.T @ (lb - x0)
ub_tr = L.T @ (ub - x0)
linear_constraints: list[LinearConstraint] = []
Aeq, beq = cons["eq"]
A_tilde = Aeq @ LinvT
b_tilde = beq - Aeq @ x0
if Aeq.size:
linear_constraints.append(LinearConstraint(A_tilde, b_tilde, b_tilde))
G, h = cons["ineq"]
if G.size:
linear_constraints.append(create_inequalities(G, h, LinvT, x0))
res = minimize(
fun=neg_loglik,
x0=y0,
jac=True,
args=(g, preconditioner_nonstatic, args_nonstatic, lb_tr, ub_tr),
method=method,
constraints=linear_constraints,
options={
"gtol": gtol,
"xtol": xtol,
"maxiter": maxiter,
"barrier_tol": barrier_tol,
},
)
x_opt = np.array(x0) + LinvT @ res.x
print("optimal value: ")
print(x_opt)
print(res)
return _vec_to_dict(jnp.asarray(x_opt), path_order), res.fun, x_opt