CCR: Cross-Coalescence Rate
To fully understand CCR, we recommend first reading the ICR page.
This note documents how to use the CCR functions and the difference between the
exact colored-CTMC implementation and the mean-field approximation.
Overview
The CCR measures the instantaneous hazard of the first red-blue coalescence.
The API mirrors demestats.icr:
- demestats.ccr.CCRCurve: exact colored lineage-count CTMC (accurate, but scales poorly with sample size and number of demes).
- demestats.ccr.CCRMeanFieldCurve: mean-field ODE approximation (scales well, slightly less accurate).
Both return a dict with:
- c: the CCR curve evaluated at the requested times.
- log_s: the log survival curve, log P(no cross-coalescence by t).
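The two entries are linked by the usual hazard/survival identity: if c(t) is the hazard of the first cross-coalescence, then log S(t) = -∫_0^t c(u) du, and the density of the first event is c(t) * S(t). The sketch below illustrates that identity in plain Python on a constant hazard. The helper names are hypothetical, and this is not a claim about how demestats computes log_s internally.

```python
import math


def log_survival_from_hazard(times, hazard):
    """Trapezoid-rule approximation of log S(t_i) = -integral_0^{t_i} c(u) du.

    Assumes times[0] == 0, so that log S(times[0]) = 0.
    """
    log_s = [0.0]
    for i in range(1, len(times)):
        dt = times[i] - times[i - 1]
        log_s.append(log_s[-1] - 0.5 * (hazard[i] + hazard[i - 1]) * dt)
    return log_s


def first_event_density(hazard, log_s):
    """Density of the first event: f(t) = c(t) * S(t)."""
    return [c * math.exp(ls) for c, ls in zip(hazard, log_s)]


# Toy input: a constant hazard of 0.5 on the grid 0.0, 0.1, ..., 2.0.
times = [0.1 * i for i in range(21)]
hazard = [0.5] * len(times)

log_s = log_survival_from_hazard(times, hazard)
dens = first_event_density(hazard, log_s)
```

For a constant hazard the survival curve is a plain exponential, so log_s falls linearly in t and the density starts at the hazard value itself.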
Background
Schiffels and Durbin (2014) introduced the cross-coalescence rate for two populations as a time-dependent coalescence hazard for lineages sampled from different groups. Intuitively, it quantifies how quickly lineages from the two groups find common ancestry as you go back in time, and is widely used as a measure of divergence and gene flow.
Here, CCR generalizes that idea to arbitrary demographies and sampling schemes via a simple thought experiment: we imagine each population is colored either red or blue, and we tag lineages accordingly. We then track the joint process of red and blue lineages across multiple demes and define the curve as the instantaneous hazard of the first red-blue coalescence event. This reduces to the Schiffels-Durbin CCR in the two-sample, two-population setting, while supporting more complex graphs, time-varying sizes, and migration histories.
Usage
import jax.numpy as jnp
import stdpopsim
from demestats.ccr import CCRCurve, CCRMeanFieldCurve
demo = stdpopsim.IsolationWithMigration(
    NA=5000, N1=4000, N2=1000, T=1000, M12=1e-4, M21=2e-4
).model.to_demes()
t = jnp.linspace(0.0, 2000.0, 200)
num_samples = {"pop1": (1, 0), "pop2": (0, 1)}
exact = CCRCurve(demo, k=2)(t=t, num_samples=num_samples, params={})
mf = CCRMeanFieldCurve(demo, k=2)(t=t, num_samples=num_samples, params={})
CCR: Exact vs mean-field
Exact CCR (CCRCurve) tracks the full colored lineage-count CTMC. The state
space grows as (k+1)^(2d) for k total samples and d demes, which becomes
intractable quickly. The implementation guards against this with
DEMESTATS_CCR_MAX_STATES and will error if the state space is too large.
Mean-field CCR (CCRMeanFieldCurve) tracks only the expected red/blue counts
per deme, which evolve under a deterministic ODE. This is much faster and
scales to large k and d, but it is an approximation.
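To get a feel for the growth quoted above, the following sketch evaluates the (k+1)^(2d) expression for a few settings. The function name is hypothetical and the expression is taken directly from the text. The number of states the implementation actually enumerates, and the default value of DEMESTATS_CCR_MAX_STATES, are not shown here.

```python
def ccr_state_bound(k: int, d: int) -> int:
    """Evaluate (k + 1) ** (2 * d) for k total samples and d demes."""
    return (k + 1) ** (2 * d)


# A few illustrative (k, d) settings, from tiny to clearly out of reach.
for k, d in [(2, 2), (10, 3), (20, 5)]:
    print(f"k={k:>2}, d={d}: {ccr_state_bound(k, d):,} states")
```

Even moderate sample sizes on a handful of demes reach millions of states, which is why the mean-field curve is the practical choice for large k.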
Numerical comparison (IWM example)
The following example compares the curves on a standard isolation-with-migration demography. In practice, the mean-field approximation tracks the exact curve closely for typical settings.
import numpy as np
import matplotlib.pyplot as plt
rel_err = np.max(
    np.abs(np.asarray(mf["c"]) - np.asarray(exact["c"]))
    / np.maximum(np.asarray(exact["c"]), 1e-12)
)
print("max relative error in c:", rel_err)
fig, ax = plt.subplots(figsize=(6.0, 3.5))
ax.plot(t, exact["c"], label="exact CCR", lw=2)
ax.plot(t, mf["c"], label="mean-field CCR", lw=2, linestyle="--")
ax.set_xlabel("time")
ax.set_ylabel("c(t)")
ax.set_title("CCR: exact vs mean-field (IWM)")
ax.legend(frameon=False)
fig.tight_layout()
Power to detect recent migration
The mean-field CCR curve can be used to detect very recent changes in migration (e.g., within the last 20 generations) when the sample size is large (e.g., k=100 per population). The following example demonstrates this power by comparing two IWM models: one with continuous migration until the present, and another where migration ceases 20 generations ago.
import demes
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from demestats.ccr import CCRMeanFieldCurve
def make_model(migration_end_time=0):
    b = demes.Builder(description="YRI-CEU like IWM")
    b.add_deme("ancestral", epochs=[dict(start_size=15000, end_time=2500)])
    b.add_deme("YRI", ancestors=["ancestral"], epochs=[dict(start_size=20000)])
    b.add_deme("CEU", ancestors=["ancestral"], epochs=[dict(start_size=10000)])
    # Symmetric migration rate of 5e-5
    b.add_migration(
        demes=["YRI", "CEU"],
        rate=5e-5,
        start_time=2500,
        end_time=migration_end_time,
    )
    return b.resolve()
# Parameters
t_max = 1000
t = jnp.linspace(0.0, t_max, 200)
ks = [1, 5, 20, 100, 200]
print(f"Comparing models with ks={ks}...")
# Compute curves for both models for each k
fig, ax = plt.subplots(figsize=(10, 7))
cmap = plt.get_cmap('viridis')
# Avoid darkest/lightest ends if desired, or just use linear
colors = cmap(np.linspace(0, 0.9, len(ks)))
for i, k in enumerate(ks):
    num_samples = {"YRI": (k, 0), "CEU": (0, k)}
    # Model 1: Continuous Migration
    graph_cont = make_model(migration_end_time=0)
    mf_cont = CCRMeanFieldCurve(graph_cont, k=2*k)(t=t, num_samples=num_samples)
    # Model 2: Truncated Migration
    graph_trunc = make_model(migration_end_time=20)
    mf_trunc = CCRMeanFieldCurve(graph_trunc, k=2*k)(t=t, num_samples=num_samples)
    # Density f(t) = rate(t) * S(t)
    dens_cont = mf_cont["c"] * jnp.exp(mf_cont["log_s"])
    dens_trunc = mf_trunc["c"] * jnp.exp(mf_trunc["log_s"])
    color = colors[i]
    ax.plot(t, dens_cont, label=f"k={k}", color=color, lw=2)
    # Only plot truncated migration for t >= 20 (where density > 0)
    mask = t >= 20
    ax.plot(t[mask], dens_trunc[mask], color=color, linestyle="--", lw=1.5, alpha=0.7)
ax.set_title("Resulting Coalescent Density (Solid=Continuous, Dashed=Truncated)")
ax.set_xlabel("Generations ago")
ax.set_ylabel("Cross-Coalescence Density (log scale)")
# Log scale, to match the axis label and the description of the plot
ax.set_yscale('log')
ax.axvline(20, color='gray', linestyle=':', alpha=0.5, label="t=20 cutoff")
# Create custom legend
from matplotlib.lines import Line2D
legend_elements = [Line2D([0], [0], color=colors[i], label=f'k={k}') for i, k in enumerate(ks)]
legend_elements.append(Line2D([0], [0], color='black', linestyle='-', label='Continuous'))
legend_elements.append(Line2D([0], [0], color='black', linestyle='--', label='Truncated'))
ax.legend(handles=legend_elements)
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
plt.show()
The plot shows the cross-coalescence density for sample sizes k in {1, 5, 20, 100, 200} on a log scale. The dashed lines represent the truncated migration model, where the probability of recent cross-coalescence drops effectively to zero (falling off the log scale) for t < 20. The solid lines show the continuous migration model. For smaller k, the density is low and the distinction between models is less pronounced in magnitude. However, as k increases, the expected density of cross-coalescence events in the recent past rises significantly, providing a strong, distinguishable signal that allows us to reject the truncated migration model.
Real data analysis
The following example demonstrates how to compute the CCR on a real dataset. We use the tree sequence from the “Unified Genomes” dataset (HGDP + 1kG + SGDP + Ancients) restricted to chromosome 20. We compute the minimum cross-coalescence time between YRI (Yoruba in Ibadan, Nigeria) and CEU (Utah Residents (CEPH) with Northern and Western European Ancestry) populations.
import tskit
import numpy as np
import matplotlib.pyplot as plt
def get_min_cross_coalescence_time(tree, samples1, samples2):
    """
    Computes the minimum time to the most recent common ancestor (TMRCA)
    between any lineage in samples1 and any lineage in samples2 in the given tree.
    """
    # 1. Collect all ancestors of samples1
    ancestors1 = set()
    current_nodes = set(samples1)
    while current_nodes:
        ancestors1.update(current_nodes)
        next_nodes = set()
        for u in current_nodes:
            p = tree.parent(u)
            if p != tskit.NULL and p not in ancestors1:
                next_nodes.add(p)
        current_nodes = next_nodes
    # 2. Traverse up from samples2, finding the minimum time of intersection
    min_time = np.inf
    current_nodes = set(samples2)
    visited2 = set()
    while current_nodes:
        next_nodes = set()
        for u in current_nodes:
            if u in ancestors1:
                t = tree.time(u)
                if t < min_time:
                    min_time = t
                # Any ancestor of u has time > t, so we don't need to continue up from here
                # to find a *lower* coalescence time.
            else:
                p = tree.parent(u)
                if p != tskit.NULL and p not in visited2:
                    next_nodes.add(p)
                    visited2.add(p)
        current_nodes = next_nodes
        # Optimization: stop if all next nodes are older than current min_time
        if min_time != np.inf:
            if not next_nodes:
                break
            min_next_time = min(tree.time(x) for x in next_nodes)
            if min_next_time > min_time:
                break
    return min_time
import json

# Load the tree sequence
try:
    ts = tskit.load("/scratch/unified/hgdp_tgp_sgdp_high_cov_ancients_chr20_p.dated.trees")
    print(f"Loaded tree sequence with {ts.num_trees} trees and {ts.num_samples} samples")

    # Helper to find population IDs
    def get_pop_id(name):
        pops_iter = ts.populations() if callable(ts.populations) else ts.populations
        for pop in pops_iter:
            if pop.metadata:
                try:
                    # Metadata may be raw JSON bytes or an already-decoded dict
                    md = json.loads(pop.metadata.decode('utf-8')) if isinstance(pop.metadata, bytes) else pop.metadata
                    if md.get('name') == name:
                        return pop.id
                except Exception:
                    # Skip populations whose metadata cannot be parsed
                    continue
        return None

    pop1 = get_pop_id("YRI")
    pop2 = get_pop_id("CEU")
    if pop1 is not None and pop2 is not None:
        samples1 = ts.samples(population=pop1)
        samples2 = ts.samples(population=pop2)
        print(f"Comparing YRI ({len(samples1)} samples) vs CEU ({len(samples2)} samples)")
        times = []
        # Downsample to ~200 trees uniformly across the sequence to balance resolution and linkage
        target_trees = 200
        step = max(1, int(ts.num_trees / target_trees))
        for i, tree in enumerate(ts.trees()):
            if i % step == 0:
                t_ccr = get_min_cross_coalescence_time(tree, samples1, samples2)
                times.append(t_ccr)
        print(f"Computed CCR for {len(times)} trees.")
        # Plot Empirical CCR (CDF)
        times = np.sort(times)
        cdf = np.arange(1, len(times) + 1) / len(times)
        plt.figure(figsize=(8, 5))
        plt.step(times, cdf, where='post', label="Empirical CCR")
        plt.xlabel("Generations ago")
        plt.ylabel("Cumulative Probability")
        plt.title("Empirical CCR: YRI vs CEU (Chr20)")
        plt.grid(True, alpha=0.3)
        plt.show()
    else:
        print("Populations not found.")
except Exception as e:
    print(f"Could not load data or run analysis: {e}")
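As a sanity check on what the statistic measures, the same quantity can be computed by hand on a tiny tree. The sketch below is plain Python with no tskit dependency: the tree is a hypothetical hand-built example stored as parent and time dictionaries, and the function is a simplified stand-in for the traversal above, not a replacement for it.

```python
def min_cross_time(parent, time, samples1, samples2):
    """Smallest node time at which a lineage from samples2 meets one from samples1."""
    # Collect every node on a path from a group-1 sample to the root.
    ancestors1 = set()
    for s in samples1:
        u = s
        while u is not None and u not in ancestors1:
            ancestors1.add(u)
            u = parent.get(u)
    # Walk up from each group-2 sample until the path hits a group-1 ancestor.
    best = float("inf")
    for s in samples2:
        u = s
        while u is not None:
            if u in ancestors1:
                best = min(best, time[u])
                break
            u = parent.get(u)
    return best


# Hypothetical toy tree: samples 0 and 1 are "red", samples 2 and 3 are "blue".
#   node 4 joins 0 and 1 at time 10 (red-red coalescence, not a cross event)
#   node 5 joins 4 and 2 at time 50 (first red-blue coalescence)
#   node 6 joins 5 and 3 at time 120 (root)
parent = {0: 4, 1: 4, 2: 5, 4: 5, 3: 6, 5: 6}
time = {0: 0, 1: 0, 2: 0, 3: 0, 4: 10, 5: 50, 6: 120}

result = min_cross_time(parent, time, samples1=[0, 1], samples2=[2, 3])
print(result)
```

The within-group coalescence at time 10 is ignored. The statistic records only the first event that joins a red lineage with a blue one.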
Fitting Demographic Parameters
We can now use the empirical CCR curve derived from the real data to fit demographic parameters. Here, we estimate the recent exponential growth rate of the CEU population and the symmetric migration rate between YRI and CEU. We use scipy.optimize to minimize the mean squared error between the empirical CDF of minimum cross-coalescence times and the CDF predicted by the mean-field model, 1 - exp(log_s).
from scipy.optimize import minimize
import demes
from demestats.ccr import CCRMeanFieldCurve
# Ensure we have the data from the previous step
if 'times' not in locals() or 'samples1' not in locals():
    print("Please run the 'Real data analysis' cell first.")
else:
    empirical_times = jnp.sort(jnp.array(times))
    empirical_cdf = jnp.arange(1, len(empirical_times) + 1) / len(empirical_times)
    # Sample sizes from the real data
    n_yri = len(samples1)
    n_ceu = len(samples2)
    k_total = n_yri + n_ceu
    num_samples = {"YRI": (n_yri, 0), "CEU": (0, n_ceu)}
    # Fixed parameters based on IWM / literature
    N_YRI = 20000
    N0_CEU = 30000  # Present day effective size approximation
    # Build template graph once
    b = demes.Builder(description="Parametric Fit")
    b.add_deme("ancestral", epochs=[dict(start_size=15000, end_time=2500)])
    b.add_deme("YRI", ancestors=["ancestral"], epochs=[dict(start_size=N_YRI)])
    # CEU: Exponential growth. Initial dummy start_size that we will bind.
    b.add_deme("CEU", ancestors=["ancestral"], epochs=[dict(start_size=1000, end_size=N0_CEU)])
    # Migration: Dummy rate that we will bind.
    b.add_migration(demes=["YRI", "CEU"], rate=1e-5, start_time=2500, end_time=0)
    graph_template = b.resolve()
    # Initialize Mean-Field Curve ONCE
    mf_template = CCRMeanFieldCurve(graph_template, k=k_total)
    # Identify variables to optimize
    # Note: Indices might vary if graph construction changes.
    # Here: ancestral=0, YRI=1, CEU=2.
    path_start_size = ('demes', 2, 'epochs', 0, 'start_size')
    path_mig_rate = ('migrations', 0, 'rate')
    var_start_size = mf_template.variable_for(path_start_size)
    var_mig_rate = mf_template.variable_for(path_mig_rate)
    import jax

    @jax.jit
    def loss_func(params):
        r, log_m = params
        m = jnp.exp(log_m)
        # Calculate dependent parameter
        target_start_size = jnp.maximum(100.0, N0_CEU * jnp.exp(-r * 2500.0))
        # Bind parameters
        # We need to collect the variables into a dictionary.
        # Note: variable_for returns a hashable Variable object, which is fine as a key.
        param_dict = {
            var_start_size: target_start_size,
            var_mig_rate: m
        }
        # Compute model CCR with bound parameters
        mf = mf_template(t=empirical_times, num_samples=num_samples, params=param_dict)
        # CDF = 1 - S(t) = 1 - exp(log_s)
        model_cdf = 1.0 - jnp.exp(mf["log_s"])
        # MSE Loss
        mse = jnp.mean((empirical_cdf - model_cdf)**2)
        # Bounds check / Prior using jnp.where for JIT compatibility
        # Penalty for r < -0.01 or r > 0.1
        # Penalty for m < 1e-8 or m > 1e-2
        out_of_bounds = (r < -0.01) | (r > 0.1) | (m < 1e-8) | (m > 1e-2)
        return jnp.where(out_of_bounds, 1e9, mse)

    print("Fitting parameters (this may take a minute)...")
    # Initial guess: r=0.001, m=1e-5
    x0 = jnp.array([0.001, np.log(1e-5)])
    # Use Nelder-Mead. Note: scipy passes numpy arrays, JIT handles them.
    # We explicitly cast return to float to satisfy scipy if needed, but usually
    # minimize works if we wrap it or if it accepts scalar arrays.
    # To be safe with scipy, we wrap the JIT function.
    def loss_wrapper(x):
        return float(loss_func(x))

    res = minimize(loss_wrapper, x0, method='Nelder-Mead', tol=1e-4, options={'maxiter': 100, 'disp': True})
    est_r = res.x[0]
    est_m = np.exp(res.x[1])
    print(f"Estimated Growth Rate (r): {est_r:.5f}")
    print(f"Estimated Migration Rate (m): {est_m:.2e}")
    # Plotting result
    fig, ax = plt.subplots(figsize=(8, 5))
    # Empirical
    ax.step(empirical_times, empirical_cdf, where='post', label="Empirical CCR", color='black', alpha=0.6)
    # Best Fit Model curve
    t_plot = jnp.linspace(0, max(empirical_times)*1.1, 200)
    target_start_size = max(100, N0_CEU * np.exp(-est_r * 2500))
    best_params = {
        var_start_size: target_start_size,
        var_mig_rate: est_m
    }
    mf_fit = mf_template(t=t_plot, num_samples=num_samples, params=best_params)
    cdf_fit = 1.0 - jnp.exp(mf_fit["log_s"])
    ax.plot(t_plot, cdf_fit, label=f"Best Fit (r={est_r:.4f}, m={est_m:.2e})", color='red', lw=2)
    ax.set_title("CCR Parameter Inference: YRI-CEU")
    ax.set_xlabel("Generations ago")
    ax.set_ylabel("CDF")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()
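The CDF-matching loss used in the fit can be tried in isolation on a toy problem. In the sketch below a constant-hazard model, S(t) = exp(-rate * t), stands in for the mean-field curve, and a coarse grid search stands in for Nelder-Mead. Both substitutions are simplifying assumptions made so the example runs with the standard library only, and all names are hypothetical.

```python
import math


def empirical_cdf(times):
    """Sorted times and the step-function CDF values (i + 1) / n."""
    xs = sorted(times)
    n = len(xs)
    return xs, [(i + 1) / n for i in range(n)]


def cdf_mse(xs, ecdf, rate):
    """Mean squared error between the empirical CDF and 1 - exp(-rate * t)."""
    sq = [(F - (1.0 - math.exp(-rate * x))) ** 2 for x, F in zip(xs, ecdf)]
    return sum(sq) / len(sq)


# Synthetic "first cross-coalescence times": evenly spaced quantiles of an
# exponential distribution with a known rate, so the right answer is known.
true_rate = 0.02
n = 50
times = [-math.log(1.0 - (i + 0.5) / n) / true_rate for i in range(n)]

xs, ecdf = empirical_cdf(times)

# Coarse grid search over candidate rates 0.005, 0.010, ..., 0.100.
grid = [0.005 * j for j in range(1, 21)]
best_rate = min(grid, key=lambda r: cdf_mse(xs, ecdf, r))
print(best_rate)
```

The same structure carries over to the real fit: only the model CDF changes (1 - exp(log_s) from the mean-field curve), and the optimizer searches over (r, log m) instead of a single rate.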