import operator
from collections import OrderedDict
from collections.abc import Mapping
from dataclasses import dataclass, field
from functools import partial, reduce
import demes
import equinox as eqx
import jax
import jax.numpy as jnp
from beartype.typing import Sequence
from jaxtyping import Array, ArrayLike, Float, Int, PyTree, ScalarLike, Shaped
from loguru import logger
from penzai import pz
import demestats.event_tree as event_tree
import demestats.sfs.events as events
from demestats.path import Path
from demestats.traverse import traverse
from .events.state import SetupState, State
# Parameter assignment: event-tree variable -> scalar value bound to it.
Params = dict[event_tree.Variable, ScalarLike]
# A single manual pruning request: (deme name, target sample size, optional
# locator -- a demes time path, an event-tree node id, or None for the leaf).
PruneSpec = tuple[
    str,
    int,
    Path | event_tree.Node | None,
]
# Every form accepted by the ``prune`` argument of ExpectedSFS.
PruneInput = (
    Mapping[str, int]
    | Sequence[PruneSpec]
    | None
)
[docs]
@dataclass
class ExpectedSFS:
    """
    Build an ExpectedSFS object that can later be used to compute the full expected
    site frequency spectrum or the projected site frequency spectrum.

    Parameters
    ----------
    demo : demes.Graph
        A ``demes`` graph.
    num_samples : dict
        A dictionary specifying how many haploids per population to use to compute
        the expected SFS. The population names must match the exact names used in
        ``demo``.
    prune : mapping or sequence, optional
        Optional manual downsampling events. Provide either a mapping
        ``{deme_name: m}`` to downsample directly above leaves, or a sequence of
        ``(deme_name, m[, at])`` tuples where ``at`` is a node id or a demes time
        path to insert the downsample event above that node.

    Returns
    -------
    ExpectedSFS
        An object used to compute the expected site frequency spectrum.

    Raises
    ------
    ValueError
        If ``num_samples`` names a deme that is not in ``demo``, or if a ``prune``
        entry is malformed or cannot be located in the event tree.

    Notes
    -----
    From a user perspective, understanding the underlying structure of an
    ExpectedSFS object is not necessary. The only methods a user would typically
    call are ``ExpectedSFS.__call__``, which computes the full expected site
    frequency spectrum, and ``ExpectedSFS.tensor_prod``, which computes the
    projected site frequency spectrum.

    Every deme that takes part in a migration and has fewer than four samples is
    internally upsampled to four samples.

    Examples
    --------
    ::

        ESFS = ExpectedSFS(demo.to_demes(), num_samples=afs_samples)

    Please refer to the tutorial for a complete example; the code above only
    outlines how to call the functions.

    See Also
    --------
    demestats.sfs.ExpectedSFS.__call__
    demestats.sfs.ExpectedSFS.tensor_prod
    """

    demo: demes.Graph
    num_samples: dict[str, Int[ScalarLike, ""]]
    prune: PruneInput = None
    # Event tree built from ``demo`` in ``__post_init__``; not a constructor argument.
    et: event_tree.EventTree = field(init=False)

    # Minimum sample count enforced for demes involved in migration. It has no
    # annotation, so it is a plain class attribute rather than a dataclass field.
    _MIN_MIGRATION_SAMPLES = 4

    def __post_init__(self):
        """Validate inputs, build the event tree and insert resampling events."""
        deme_names = {deme.name for deme in self.demo.demes}
        unknown = self.num_samples.keys() - deme_names
        if unknown:
            raise ValueError(
                "num_samples must contain only deme names from the demo, found {} which is not in {}".format(
                    unknown,
                    deme_names,
                )
            )
        self.et = et = event_tree.EventTree(self.demo, events=events)

        # Continuous migration requires at least four nodes. For now this is
        # enforced for every deme that appears in a migration, which may be
        # slightly wasteful: each such deme with fewer samples gets an Upsample
        # event spliced in right above its leaf.
        migrating = {
            name for mig in self.demo.migrations for name in (mig.source, mig.dest)
        }
        leaves = et.leaves
        for deme in self.demo.demes:
            pop = deme.name
            if pop not in migrating:
                continue
            if self.num_samples.get(pop, 0) >= self._MIN_MIGRATION_SAMPLES:
                continue
            logger.debug("Upsampling {} to 4 samples", pop)
            self._insert_event_above(
                leaves[pop],
                events.Upsample(pop=pop, m=self._MIN_MIGRATION_SAMPLES),
            )
        et._check()

        if self.prune:
            self._apply_pruning(self.prune)
        self._aux = self._setup()

    def _insert_event_above(
        self,
        node: event_tree.Node,
        event,
        t_override: Path | None = None,
        ti_override: int | None = None,
    ) -> event_tree.Node:
        """Splice a new node carrying ``event`` between ``node`` and its parent.

        The new node copies the ``t``, ``block`` and ``ti`` attributes of ``node``.
        ``t_override`` / ``ti_override`` replace them when not ``None``. Any label
        on the original ``node -> parent`` edge is moved to the new
        ``new -> parent`` edge. The tree is not re-validated here; callers invoke
        ``et._check()`` themselves.

        Returns
        -------
        The id of the inserted node.

        Raises
        ------
        ValueError
            If ``node`` does not have exactly one parent.
        """
        et = self.et
        # Check the topology before mutating the tree, so that a failure does not
        # leave a dangling node behind.
        parents = list(et._T.successors(node))
        if len(parents) != 1:
            raise ValueError(
                f"node {node} must have exactly one parent to insert an event "
                f"above it, found {len(parents)}"
            )
        (parent,) = parents
        kw = {k: et.nodes[node][k] for k in ("t", "block", "ti")}
        if t_override is not None:
            kw["t"] = t_override
        if ti_override is not None:
            kw["ti"] = ti_override
        kw["event"] = event
        v = et._add_node(**kw)
        label = et.edges[node, parent].get("label")
        et._remove_edge(node, parent)
        et._add_edge(node, v)
        et._add_edge(v, parent)
        # Reconstruct the edge label if necessary.
        if label is not None:
            et.edges[v, parent]["label"] = label
        return v

    def _apply_pruning(self, prune: PruneInput) -> None:
        """Insert the manual ``Downsample`` events requested by ``prune``.

        Raises
        ------
        ValueError
            If an entry names an unknown deme, has a negative ``m``, or its
            locator cannot be resolved.
        """
        specs = self._normalize_prune(prune)
        if not specs:
            return
        deme_names = {d.name for d in self.demo.demes}
        for pop, m, at in specs:
            if pop not in deme_names:
                raise ValueError(f"unknown deme {pop} in pruning spec")
            if m < 0:
                raise ValueError(
                    f"pruning sample size must be non-negative, got {m} for {pop}"
                )
            if isinstance(at, tuple):
                # A demes time path: resolve to a node (or an edge crossing it).
                target = self._resolve_prune_target(pop, at)
                self._insert_downsample(
                    target["node"],
                    pop,
                    m,
                    t_override=target.get("t_override"),
                    ti_override=target.get("ti_override"),
                )
            else:
                node = self._resolve_prune_node(pop, at)
                self._insert_downsample(node, pop, m)
        # Drop ``T_reduced`` from the tree's instance dict (presumably a cached
        # property) so it is recomputed with the newly inserted nodes.
        self.et.__dict__.pop("T_reduced", None)

    def _normalize_prune(self, prune: PruneInput) -> list[PruneSpec]:
        """Convert any accepted ``prune`` form into a list of ``(pop, m, at)`` tuples.

        A mapping ``{pop: m}`` becomes ``(pop, m, None)`` entries. Sequence entries
        may be ``(pop, m)`` (``at`` defaults to ``None``) or ``(pop, m, at)``.
        ``m`` is coerced with ``int``.

        Raises
        ------
        ValueError
            If a sequence entry does not have two or three elements.
        """
        if prune is None:
            return []
        if isinstance(prune, Mapping):
            return [(pop, int(m), None) for pop, m in prune.items()]
        specs = []
        for item in prune:
            if len(item) == 2:
                pop, m = item
                at = None
            elif len(item) == 3:
                pop, m, at = item
            else:
                raise ValueError("prune entries must be (pop, m) or (pop, m, at)")
            specs.append((pop, int(m), at))
        return specs

    def _resolve_prune_node(
        self, pop: str, at: Path | event_tree.Node | None
    ) -> event_tree.Node:
        """Return the node above which a downsample of ``pop`` is inserted.

        ``None`` selects the leaf of ``pop``. An ``int`` is treated as a node id,
        which must exist in the tree and whose ``block`` must contain ``pop``.

        Raises
        ------
        ValueError
            If the node does not exist, its block lacks ``pop``, or ``at`` has an
            unsupported type.
        """
        et = self.et
        if at is None:
            return et.leaves[pop]
        if isinstance(at, int):
            if at not in et.nodes:
                raise ValueError(f"prune node {at} not in event tree")
            if pop not in et.nodes[at]["block"]:
                raise ValueError(f"deme {pop} not in node {at} block")
            return at
        raise ValueError(f"invalid prune locator {at!r}")

    def _resolve_prune_target(self, pop: str, at: Path) -> dict:
        """Resolve a time-path locator ``at`` for a downsample of ``pop``.

        The method first looks for a unique node whose ``t`` equals ``at`` and
        whose block contains ``pop``; if found, it returns ``{"node": n}``.
        Otherwise it looks for a unique edge whose parent has ``t == at`` and
        whose child block contains ``pop``. In that case it returns the child
        together with ``t_override=at`` and the parent's ``ti`` as
        ``ti_override``.

        Raises
        ------
        ValueError
            If several nodes match, or if no node matches and the number of
            matching edges is not exactly one.
        """
        et = self.et
        matches = [
            n
            for n in et.nodes
            if et.nodes[n]["t"] == at and pop in et.nodes[n]["block"]
        ]
        if len(matches) == 1:
            return {"node": matches[0]}
        if len(matches) > 1:
            raise ValueError(
                f"prune locator {at} matched {len(matches)} nodes for {pop}: {matches}"
            )
        edge_matches = [
            (child, parent)
            for child, parent in et.edges
            if et.nodes[parent]["t"] == at and pop in et.nodes[child]["block"]
        ]
        if len(edge_matches) != 1:
            raise ValueError(
                f"prune locator {at} matched {len(edge_matches)} edges for {pop}: {edge_matches}"
            )
        child, parent = edge_matches[0]
        return {
            "node": child,
            "t_override": at,
            "ti_override": et.nodes[parent].get("ti"),
        }

    def _insert_downsample(
        self,
        node: event_tree.Node,
        pop: str,
        m: int,
        t_override: Path | None = None,
        ti_override: int | None = None,
    ) -> None:
        """Insert a ``Downsample(pop, m)`` event above ``node`` and re-check the tree.

        Raises
        ------
        ValueError
            If ``node`` does not have exactly one parent.
        """
        self._insert_event_above(
            node,
            events.Downsample(pop=pop, m=m),
            t_override=t_override,
            ti_override=ti_override,
        )
        self.et._check()

    def bind(self, params: Params) -> dict:
        """
        Bind the parameters to the event tree's demo (``EventTree.bind`` with
        ``rescale=True``) and return the resulting dict.
        """
        return self.et.bind(params, rescale=True)

    def variable_for(self, path: Path) -> event_tree.Variable:
        """Return the variable associated with a given path."""
        return self.et.variable_for(path)

    @property
    def variables(self) -> Sequence[event_tree.Variable]:
        """
        Return the parameters that can be optimized.
        """
        return self.et.variables

    def _setup(self) -> dict[tuple[event_tree.Node, ...], dict]:
        """Run the setup traversal and return the auxiliary data it produces.

        Each leaf starts from a ``SetupState`` whose single axis has length
        ``num_samples[pop] + 1`` (unsampled demes count as 0 samples).
        """
        setup_state = {}
        for pop, leaf in self.et.leaves.items():
            n = self.num_samples.get(pop, 0)
            setup_state[(leaf,)] = SetupState(
                migrations=frozenset(),
                axes=OrderedDict({pop: n + 1}),
                ns={pop: {pop: n}},
            )
        _, aux = traverse(
            self.et,
            setup_state,
            node_callback=lambda node, node_attrs, **kw: node_attrs["event"].setup(
                demo=self.et.demodict,
                **kw,
            ),
            lift_callback=partial(events.setup_lift, demo=self.et.demodict),
            aux=None,
            scan_over_lifts=False,
        )
        return aux

    def __call__(self, params: Params | None = None) -> Float[Array, "*T"]:
        """
        Compute the full expected site frequency spectrum under a given set of
        model parameter values.

        Parameters
        ----------
        params : dict, optional
            A dictionary of model parameters and their values. Defaults to no
            overrides.

        Returns
        -------
        Float[Array]
            The full expected SFS, with one axis of length ``n + 1`` per entry of
            ``num_samples`` (in that order). The all-ancestral and all-derived
            (monomorphic) entries are set to zero.

        Notes
        -----
        You must first construct an ExpectedSFS object. See the ExpectedSFS API.

        Examples
        --------
        ::

            ESFS = ExpectedSFS(demo.to_demes(), num_samples=afs_samples)
            params = {param_key: val}
            esfs = ESFS(params)

        Please refer to the tutorial for a complete example; the code above only
        outlines how to call the functions.

        See Also
        --------
        demestats.sfs.ExpectedSFS
        """
        if params is None:
            params = {}
        bs = [n + 1 for n in self.num_samples.values()]
        # One row per joint configuration of derived-allele counts, in C order
        # over the populations of ``num_samples``.
        num_derived = jnp.indices(bs)
        num_derived = jnp.rollaxis(num_derived, 0, num_derived.ndim).reshape(
            -1, len(bs)
        )

        def f(nd):
            nd = dict(zip(self.num_samples, nd))
            ret = {}
            for pop in self.et.leaves:
                # Ghost populations may not be sampled; they then have a trivial
                # partial leaf likelihood (a one-hot vector of length 1).
                n = self.num_samples.get(pop, 0)
                d = nd.get(pop, 0)
                ret[pop] = jax.nn.one_hot(jnp.array([d]), n + 1)[0]
            return ret

        X = jax.vmap(f)(num_derived)
        res = self.dp(params, X)
        # Flat index 0 is all-ancestral and -1 is all-derived: zero both.
        return res.at[jnp.array([0, -1])].set(0.0).reshape(bs)

    def tensor_prod(
        self,
        X: PyTree[Shaped[ArrayLike, "B ?D"], "T"],
        params: Params | None = None,
    ) -> Float[Array, "B"]:
        """
        Compute the projected expected site frequency spectrum under a given set of
        random projection vectors and model parameters.

        A tensor product between the random projections and the expected site
        frequency spectrum (with monomorphic entries excluded) yields the
        projected SFS. To obtain suitable projection vectors, use
        ``demestats.loglik.sfs_loglik.prepare_projection``.

        Parameters
        ----------
        X : dict
            A dictionary mapping population names to ``(B, n + 1)`` arrays of
            projection vectors, where ``n`` is that population's sample size.
        params : dict, optional
            A dictionary of model parameters and their values. Defaults to no
            overrides.

        Returns
        -------
        Float[Array]
            One projected expected SFS value per projection row ``B``.

        Raises
        ------
        ValueError
            If a projection array does not have ``num_samples[pop] + 1`` columns.

        Notes
        -----
        You must first construct an ExpectedSFS object. See the ExpectedSFS API.

        Examples
        --------
        ::

            proj_dict, einsum_str, input_arrays = prepare_projection(afs, afs_samples, sequence_length, num_projections, seed)
            esfs_obj = ExpectedSFS(demo.to_demes(), num_samples=afs_samples)
            lowdim_esfs = esfs_obj.tensor_prod(proj_dict, params)

        Please refer to the ``Random Projection`` section for a complete example;
        the code above only outlines how to call the functions.

        See Also
        --------
        demestats.loglik.sfs_loglik.prepare_projection
        """
        if params is None:
            params = {}
        demo = self.bind(params)
        for pop in X:
            n = self.num_samples.get(pop, 0)
            ncols = X[pop].shape[1]
            if ncols != n + 1:
                raise ValueError(
                    f"projection for {pop} has {ncols} columns, "
                    f"expected num_samples + 1 = {n + 1}"
                )

        def f(v):
            # Prepend the unit vectors e_0 and e_n so that the all-ancestral and
            # all-derived entries of the SFS are evaluated in the same batch.
            u = jnp.eye(v.shape[1])[jnp.array([0, -1])]
            return jnp.concatenate([u, v])

        Xa = jax.tree.map(f, X)
        states = _call(
            Xa,
            self.et,
            demo,
            self._aux,
            0.0,
        )
        # Pi[:, 0] / Pi[:, 1] are the products over populations of each projection
        # row's first / last coordinate, i.e. the weights the monomorphic entries
        # receive in the tensor product.
        Pi = reduce(operator.mul, jax.tree.map(lambda a: a[:, [0, -1]], X).values())
        ret = states.phi[2:]
        ret -= Pi[:, 0] * states.phi[0]
        ret -= Pi[:, 1] * states.phi[1]
        return ret * self.et.scaling_factor

    def dp(
        self,
        params: Params,
        X: dict[str, Float[ArrayLike, "batch *T"]],
    ) -> Float[Array, "batch"]:
        """Evaluate the event tree on a batch of leaf vectors.

        ``X`` maps population names to batched leaf vectors (leading ``batch``
        axis). The result is the root ``phi`` for each batch element, multiplied
        by the event tree's ``scaling_factor``.
        """
        demo = self.bind(params)
        state = _call(
            X,
            self.et,
            demo,
            self._aux,
            0.0,
        )
        return state.phi * self.et.scaling_factor
@eqx.filter_jit
@eqx.filter_vmap(in_axes=(0,) + (None,) * 4)
def _call(
    X: dict[str, Float[Array, "T"]],
    et: event_tree.EventTree,
    demo: dict,
    aux: dict,
    phi0: ArrayLike,
) -> State:
    """Evaluate the event tree for one batch element and return the root state.

    ``X`` is mapped over its leading axis by ``filter_vmap``; ``et``, ``demo``,
    ``aux`` and ``phi0`` are shared across the batch. Every leaf state starts
    with ``phi = phi0``. Populations missing from ``X`` use ``jnp.ones(1)`` as
    their leaf vector.
    """

    def make_leaf(pop):
        vec = X.get(pop, jnp.ones(1))
        return State(pl=pz.nx.wrap(vec, pop), phi=phi0, l0=vec[0])

    initial = {(leaf,): make_leaf(pop) for pop, leaf in et.leaves.items()}

    def on_node(node, node_attrs, **kw):
        # Always evaluate node events against the bound ``demo``.
        return node_attrs["event"](**{**kw, "demo": demo})

    final, _ = traverse(
        et,
        initial,
        on_node,
        partial(events.lift, demo=demo),
        aux=aux,
    )
    return final[(et.root,)]