Source code for demestats.ccr.exact

"""Exact CCR implementation (colored lineage-count CTMC)."""

import os
from dataclasses import dataclass, field
from functools import partial

import demes
import jax

import demestats.event_tree as event_tree
from demestats.traverse import traverse

from ...icr.state import SetupState
from ..curve import CCRCurveBase
from . import events, interp, lift, state


@jax.tree_util.register_dataclass
@dataclass
class CCRCurve(CCRCurveBase):
    """
    Exact cross-coalescence rate (CCR) curve for a demographic model.

    Construct a ``CCRCurve`` once, then call it to evaluate the CCR through
    time for that model.

    Parameters
    ----------
    demo : demes.Graph
        A ``demes`` graph describing the demographic history.
    k : int
        The number of sampled lineages used to define the CCR curve.

    Returns
    -------
    CCRCurve
        A ``CCRCurve`` object that can be called directly on a time grid,
        sampling configuration, and parameter value.

    Notes
    -----
    Users do not need to understand the internal structure of a
    ``CCRCurve``. The only entry point a user is expected to need is
    ``__call__``, which evaluates the CCR on a grid of time points given a
    sampling configuration. Two computational backends are available
    depending on the size of the problem; see "CCR: Exact vs Mean-Field" in
    the CCR tutorial.

    ::

        # build a curve that tracks k = 10 sampled lineages
        ccr = CCRCurve(demo.to_demes(), k=10)

    See Also
    --------
    demestats.ccr.CCRMeanFieldCurve
    """

    # Dataclass fields. ``k`` is tagged static in its field metadata.
    demo: demes.Graph
    k: int = field(metadata={"static": True})

    # Plain class attributes (deliberately un-annotated, so they are not
    # dataclass fields): the event/lift modules and the traversal flag.
    events_mod = events
    lift_mod = lift
    scan_over_lifts = False

    def _setup_aux(
        self, et: event_tree.EventTree
    ) -> dict[tuple[event_tree.Node, ...], dict]:
        """Traverse ``et`` in setup mode and return item ``[1]`` of the result.

        Every leaf of the event tree is seeded with a ``SetupState`` that
        holds ``self.k`` lineages in that leaf's population.
        """
        # One entry per leaf, keyed by the 1-tuple containing its node.
        leaf_states = {}
        for pop_name, leaf_node in et.leaves.items():
            leaf_states[(leaf_node,)] = SetupState(n=self.k, pops=(pop_name,))

        def _run_event_setup(node, node_attrs, **extra):
            # Delegate to the ``setup`` method of the event stored on the node.
            return node_attrs["event"].setup(**extra, demo=et.demodict)

        traversal = traverse(
            et,
            leaf_states,
            node_callback=_run_event_setup,
            lift_callback=partial(lift.setup, demo=et.demodict),
            aux=None,
            scan_over_lifts=self.scan_over_lifts,
        )
        return traversal[1]

    def _map_times(self, f, t):
        """Apply ``f`` to each entry along the leading axis of ``t``.

        ``jax.vmap`` is used while the number of time points does not exceed
        the ``DEMESTATS_CCR_VMAP_MAX_T`` environment variable (default
        ``32``); above that, ``jax.lax.map`` is used instead.
        """
        vmap_limit = int(os.environ.get("DEMESTATS_CCR_VMAP_MAX_T", "32"))
        num_times = t.shape[0]
        if num_times > vmap_limit:
            # Avoid vmapping `expm_multiply` for large CCR state spaces, which
            # can cause massive peak memory usage (especially on GPU).
            return jax.lax.map(f, t)
        return jax.vmap(f)(t)
# Names exported by ``from ... import *``: the CCRCurve class plus the
# ``events``, ``interp``, ``lift`` and ``state`` submodules imported above.
__all__ = ["CCRCurve", "events", "interp", "lift", "state"]