ICR Optimization¶
This page demonstrates a constrained optimization workflow with demestats. It stands on its own, separate from the ICR tutorial.
This tutorial shows how to create a customizable SciPy optimizer for any demographic model using scipy.optimize.minimize. It is not a primer on numerical optimization. SciPy's API and behaviour can change across
versions; we are not responsible for any changes or errors in scipy.optimize.minimize.
You can find the code in docs/tutorial_code/icr_optimization.ipynb.
Imports¶
from itertools import combinations
from typing import Any, List, Mapping, Set, Tuple
import jax
import jax.numpy as jnp
import jax.random as jr
import msprime as msp
import numpy as np
from jax import vmap
from loguru import logger
from scipy.optimize import LinearConstraint, minimize
from demestats.fit.util import (
_dict_to_vec,
_vec_to_dict,
_vec_to_dict_jax,
create_inequalities,
make_whitening_from_hessian,
pullback_objective,
)
from demestats.loglik.icr_loglik import icr_loglik
from demestats.icr import ICRCurve
Revisiting the IWM model¶
demo = msp.Demography()
demo.add_population(initial_size=5000, name="anc")
demo.add_population(initial_size=5000, name="P0")
demo.add_population(initial_size=5000, name="P1")
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate=0.0001)
demo.add_population_split(time=1000, derived=["P0", "P1"], ancestral="anc")
sample_size = 30 # we simulate 30 diploids
samples = {"P0": sample_size, "P1": sample_size}
ts = msp.sim_mutations(
    msp.sim_ancestry(
        samples=samples, demography=demo,
        recombination_rate=1e-8, sequence_length=1e8, random_seed=12
    ),
    rate=1e-8, random_seed=13
)
data, cfg_list = get_tree_from_positions_data_efficient(
    ts, num_samples=300, k=10, option="random", gap=200000, seed=8
)
Note that the data obtained is a set of k-order first-coalescence times across trees that are separated by a distance of gap.
cfg_list is the associated sampling configuration. Please refer to get_tree_from_positions_data_efficient in the API.
Inspecting and selecting parameters to optimize¶
Inspect the parameters you can work with:
et = EventTree(demo.to_demes())
et.variables
Suppose we now wish to optimize the following parameters. Their associated values serve as the initial guesses in the optimization process. We collect those parameters into a dictionary:
paths = {
    frozenset({('demes', 0, 'epochs', 0, 'end_size'),
               ('demes', 0, 'epochs', 0, 'start_size')}): 3000.,
    frozenset({('demes', 1, 'epochs', 0, 'end_size'),
               ('demes', 1, 'epochs', 0, 'start_size')}): 6000.,
    frozenset({('demes', 2, 'epochs', 0, 'end_size'),
               ('demes', 2, 'epochs', 0, 'start_size')}): 4000.,
}
cons = create_constraints(demo.to_demes(), paths)
The create_constraints function is a helper that takes a dictionary of parameters and calls constraints_for to produce the associated constraints. Any constrained optimization method needs four ingredients: the set of parameters to optimize, the constraints on them, a way to compute the expected ICR, and a way to compute the likelihood and its gradient.
Initial required setup¶
###### Part 1 #####
path_order: List[Var] = list(paths) # convert parameters into list
x0 = jnp.array(_dict_to_vec(paths, path_order)) # convert initial values into a vector
lb = jnp.array([0, 0, 0])
ub = jnp.array([1e8, 1e8, 1e8])
cfg_mat, deme_names = process_data(cfg_list)
# Note that we are choosing k = 10, entries of cfg_list MUST add up to the k you select
icr = ICRCurve(demo=demo, k=10)
icr_call = jax.jit(icr.__call__)
##### Part 2 ######
args_nonstatic = (path_order, data, cfg_mat)
args_static = (icr_call, deme_names)
L, LinvT = make_whitening_from_hessian(_compute_icr_likelihood, x0, args_nonstatic, args_static)
preconditioner_nonstatic = (x0, LinvT)
g = pullback_objective(_compute_icr_likelihood, args_static)
y0 = np.zeros_like(x0)
lb_tr = L.T @ (lb - x0)
ub_tr = L.T @ (ub - x0)
For scipy.optimize.minimize, the setup requires two parts, marked Part 1 and Part 2 in the code above. The construction of the likelihood functions is described at the end of the tutorial.
Part 1¶
While dictionary representations provide intuitive tracking of parameter states, most numerical optimizers require vectorized inputs. To address this, parameter dictionaries must be transformed into compatible vector formats. The function process_data takes a dictionary cfg_list and splits it into two JAX-compatible vectors.
Based on our benchmarking with scipy.optimize.minimize, we recommend explicitly defining lower (lb) and upper (ub) bounds to constrain the search space, which significantly improves optimization performance and stability.
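The dictionary-to-vector round trip described above can be sketched as follows. This is a minimal illustration with hypothetical helpers dict_to_vec and vec_to_dict; it is not the implementation of the library's _dict_to_vec and _vec_to_dict, whose internals are not shown on this page. The important point is that one fixed key order is reused in both directions.

```python
import numpy as np

# Hypothetical stand-ins for the library helpers (assumption, not the real code).
def dict_to_vec(params, order):
    """Flatten a parameter dictionary into a vector, following a fixed key order."""
    return np.array([params[key] for key in order], dtype=float)

def vec_to_dict(vec, order):
    """Rebuild the dictionary from a vector using the same key order."""
    return {key: float(value) for key, value in zip(order, vec)}

# Keys mimic the frozenset-of-paths style used in this tutorial.
params = {
    frozenset({("demes", 0, "epochs", 0, "start_size"),
               ("demes", 0, "epochs", 0, "end_size")}): 3000.0,
    frozenset({("demes", 1, "epochs", 0, "start_size"),
               ("demes", 1, "epochs", 0, "end_size")}): 6000.0,
}

order = list(params)                 # fix the ordering once and reuse it everywhere
x0 = dict_to_vec(params, order)      # vector handed to the optimizer
roundtrip = vec_to_dict(x0, order)   # dictionary recovered from the vector
```

Keeping order as an explicit list avoids relying on dictionary iteration order when the optimizer hands a plain vector back.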
Part 2¶
This step is optional and depends on the user's preference. In our experience with scipy.optimize.minimize, parameters such as the migration rate and the population sizes differ by several orders of magnitude, so the gradient components are badly scaled and parameter updates become unstable. We implemented a preconditioning method that makes the problem better suited to optimization. Instead of optimizing the likelihood function _compute_icr_likelihood directly, we transform the likelihood function and the bounds and optimize a function g with better conditioning. For more information, consult a reference on preconditioning.
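To illustrate the idea, here is a minimal sketch of Hessian-based whitening on a toy quadratic. It assumes the whitening uses a Cholesky factor H = L Lᵀ and the change of variables x = x0 + L⁻ᵀ y, which is consistent with the transforms used on this page (x_opt = x0 + LinvT @ res.x and lb_tr = L.T @ (lb - x0)). Whether make_whitening_from_hessian is implemented this way is an assumption; the Hessian and parameter values below are made up.

```python
import numpy as np

# Toy ill-conditioned quadratic: f(x) = 0.5 * (x - x_star)^T H (x - x_star).
# The first coordinate behaves like a population size, the second like a migration rate.
H = np.array([[1e-6, 0.0],
              [0.0, 1e8]])
x_star = np.array([5000.0, 1e-4])   # minimizer of the toy objective
x0 = np.array([3000.0, 2e-4])       # initial guess

def f(x):
    d = x - x_star
    return 0.5 * d @ H @ d

# Cholesky factor H = L L^T and the matrix L^{-T} used for the change of variables.
L = np.linalg.cholesky(H)
LinvT = np.linalg.inv(L).T

def g(y):
    """Pulled-back objective: f evaluated at x = x0 + LinvT @ y."""
    return f(x0 + LinvT @ y)

H_y = LinvT.T @ H @ LinvT           # Hessian of g in the whitened coordinates
y_star = L.T @ (x_star - x0)        # minimizer expressed in y-space
x_back = x0 + LinvT @ y_star        # mapped back to the original parameters
```

In the whitened coordinates the Hessian is the identity, so every direction has the same curvature, and the starting point x0 corresponds to y = 0.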
Constraints and scipy.optimize.LinearConstraint¶
Use constraints_for to derive the linear constraints for your chosen
parameters. It returns a dict with:
"eq" → (Aeq, beq) for equality constraints
"ineq" → (G, h) for inequalities
These map directly to SciPy’s scipy.optimize.LinearConstraint:
linear_constraints: list[LinearConstraint] = []
Aeq, beq = cons["eq"]
A_tilde = Aeq @ LinvT
b_tilde = beq - Aeq@x0
if Aeq.size:
linear_constraints.append(LinearConstraint(A_tilde, b_tilde, b_tilde))
G, h = cons["ineq"]
if G.size:
linear_constraints.append(create_inequalities(G, h, LinvT, x0, size=len(paths)))
As explained in Model Constraints, use create_inequalities to convert the output of constraints_for into the appropriate scipy.optimize.LinearConstraint format.
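The substitution behind A_tilde and b_tilde can be checked on a small example. The sketch below uses made-up matrices and a placeholder diagonal LinvT. It assumes inequalities take the form G x ≤ h and transform the same way as the equalities; the actual behaviour of create_inequalities is not shown on this page.

```python
import numpy as np
from scipy.optimize import LinearConstraint

x0 = np.array([3000.0, 6000.0, 4000.0])
LinvT = np.diag([10.0, 20.0, 5.0])   # placeholder whitening matrix (assumption)

# Illustrative equality in x-space: x[0] - x[1] = -3000
Aeq = np.array([[1.0, -1.0, 0.0]])
beq = np.array([-3000.0])
# Illustrative inequality in x-space: x[2] - x[0] <= 2000
G = np.array([[-1.0, 0.0, 1.0]])
h = np.array([2000.0])

# Substitute x = x0 + LinvT @ y into A x = b and G x <= h.
A_tilde, b_tilde = Aeq @ LinvT, beq - Aeq @ x0
G_tilde, h_tilde = G @ LinvT, h - G @ x0

constraints = [
    LinearConstraint(A_tilde, b_tilde, b_tilde),   # equality: lb == ub
    LinearConstraint(G_tilde, -np.inf, h_tilde),   # one-sided inequality
]

# The residuals agree in both coordinate systems for an arbitrary y.
y = np.array([0.3, -1.2, 2.0])
x = x0 + LinvT @ y
eq_resid_x, eq_resid_y = Aeq @ x - beq, A_tilde @ y - b_tilde
ineq_slack_x, ineq_slack_y = h - G @ x, h_tilde - G_tilde @ y
```

Because the residuals match for any y, a point is feasible in the whitened space exactly when its image is feasible in the original parameter space.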
Create and run the optimizer¶
The final step is to run an optimizer. Here we use scipy.optimize.minimize with the constrained method "trust-constr". Please refer to the scipy.optimize.minimize documentation for the available options.
gtol = 1e-5
xtol = 1e-5
maxiter = 1000
barrier_tol = 1e-5
res = minimize(
    fun=neg_loglik,
    x0=y0,
    jac=True,  # neg_loglik returns (value, gradient)
    args=(g, preconditioner_nonstatic, args_nonstatic, lb_tr, ub_tr),
    method="trust-constr",
    constraints=linear_constraints,
    options={
        'gtol': gtol,
        'xtol': xtol,
        'maxiter': maxiter,
        'barrier_tol': barrier_tol,
    }
)
x_opt = np.array(x0) + LinvT @ res.x
Due to preconditioning, we must transform the variable to inspect our final estimates:
print("Optimal parameters: ", _vec_to_dict(jnp.asarray(x_opt), path_order))
print("\nFinal negative log-likelihood evaluation: ", res.fun)
print("Optimal parameters as a vector: ", x_opt)
For the simulated example, the estimates are close to the true values:
Optimal parameters:
{frozenset({('demes', 0, 'epochs', 0, 'end_size'), ('demes', 0, 'epochs', 0, 'start_size')}): 5016.814453125,
frozenset({('demes', 1, 'epochs', 0, 'start_size'), ('demes', 1, 'epochs', 0, 'end_size')}): 5238.4287109375,
frozenset({('demes', 2, 'epochs', 0, 'start_size'), ('demes', 2, 'epochs', 0, 'end_size')}): 5025.2666015625}
Final negative log-likelihood evaluation: 430751.8125
Optimal parameters as a vector: [5016.8145 5238.4287 5025.2666]
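For readers who want to try the call pattern without the demographic model, here is a self-contained toy run of scipy.optimize.minimize with method="trust-constr", jac=True and a LinearConstraint. The objective, constraint and tolerances are made up for illustration and have nothing to do with the ICR likelihood.

```python
import numpy as np
from scipy.optimize import LinearConstraint, minimize

def objective(x):
    """Return (value, gradient), matching jac=True in the call below."""
    target = np.array([3.0, 1.0])
    diff = x - target
    return float(diff @ diff), 2.0 * diff

# Linear inequality x[0] - x[1] <= 0, written as -inf <= A x <= 0.
constraint = LinearConstraint(np.array([[1.0, -1.0]]), -np.inf, 0.0)

res = minimize(
    fun=objective,
    x0=np.array([0.0, 1.0]),        # strictly feasible starting point
    jac=True,
    method="trust-constr",
    constraints=[constraint],
    options={"gtol": 1e-8, "xtol": 1e-8, "maxiter": 1000, "barrier_tol": 1e-8},
)
```

The unconstrained minimum (3, 1) violates the constraint, so the solution lies on the boundary x[0] = x[1]; analytically the constrained minimizer is (2, 2) with objective value 2.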
For convenience, the full pipeline described above is wrapped in a single function, fit, in the demestats.fit.fit_icr module. See the API documentation for available options and implementation details. Each component of the optimization pipeline in demestats can be modified and run on its own, but if you prefer the wrapper, call it as follows:
from demestats.fit.fit_icr import fit
optimal_params, final_loglikelihood, optimal_params_vector = fit(demo.to_demes(), paths, data, cfg_list, cons, lb, ub, k=10)
Understanding the objective function¶
Below is an example of the objective function we designed for scipy.optimize.minimize, intended to give users a better understanding of how to build an inference pipeline.
def _compute_icr_likelihood(vec, args_nonstatic, args_static):
    path_order, data, cfg_mat = args_nonstatic
    icr_call, deme_names = args_static
    params = _vec_to_dict_jax(vec, path_order)
    jax.debug.print("param: {vec}", vec=vec)
    batched_loglik = vmap(compute_loglik, in_axes=(0, 0, None, None, None))(
        data, cfg_mat, params, icr_call, deme_names
    )
    # for debugging
    # jax.debug.print("batched_loglik: {}", batched_loglik)
    loss = -jnp.sum(batched_loglik)
    jax.debug.print("Loss: {loss}", loss=loss)
    return loss

def neg_loglik(vec, g, preconditioner_nonstatic, args_nonstatic, lb, ub):
    if jnp.any(vec >= ub):
        return jnp.inf, jnp.full_like(vec, 1e10)
    if jnp.any(vec <= lb):
        return jnp.inf, jnp.full_like(vec, -1e10)
    return g(vec, preconditioner_nonstatic, args_nonstatic)
A typical objective function for scipy.optimize.minimize looks something like neg_loglik. Here vec holds the parameter values, g computes the likelihood and its gradient in the transformed space by indirectly calling _compute_icr_likelihood, and the remaining arguments are the inputs needed to compute the likelihood. The variables lb and ub limit the parameter search space: outside them, neg_loglik returns an infinite loss together with a large gradient. The _compute_icr_likelihood function unpacks all of the arguments and evaluates the ICR at every provided k-order first-coalescence time.
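The bound-guarding pattern in neg_loglik can be tried in isolation. The sketch below is a NumPy version with a hypothetical guarded_objective wrapper and a toy quadratic in place of g; it mirrors the control flow shown above and does not reproduce the library code.

```python
import numpy as np

def guarded_objective(vec, value_and_grad, lb, ub):
    """Return (inf, large gradient) outside the open box (lb, ub), else the real objective."""
    if np.any(vec >= ub):
        return np.inf, np.full_like(vec, 1e10)
    if np.any(vec <= lb):
        return np.inf, np.full_like(vec, -1e10)
    return value_and_grad(vec)

def toy_value_and_grad(vec):
    """Toy stand-in for g: squared norm and its gradient."""
    return float(vec @ vec), 2.0 * vec

lb = np.array([-1.0, -1.0])
ub = np.array([1.0, 1.0])

inside = guarded_objective(np.array([0.5, -0.5]), toy_value_and_grad, lb, ub)
above = guarded_objective(np.array([0.5, 2.0]), toy_value_and_grad, lb, ub)
below = guarded_objective(np.array([-3.0, 0.0]), toy_value_and_grad, lb, ub)
```

Inside the box the wrapper passes the call through unchanged. Outside it, the infinite value tells the optimizer to reject the step, and the expensive likelihood is never evaluated at an invalid parameter vector.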