Source code for demestats.loglik.icr_loglik
import jax.numpy as jnp
def icr_loglik(time, sample_config, params, icr_call, deme_names):
    """
    Compute the log-likelihood contribution from an ICR evaluation at the given
    times and sampling configuration.

    Parameters
    ----------
    time : ArrayLike
        One or more time points at which the ICR quantities are evaluated.
    sample_config : array of int
        An array giving the number of sampled haploids for each deme, ordered
        consistently with ``deme_names``.
    params : dict
        A dictionary of parameter values passed through to ``icr_call``.
    icr_call : Callable
        A callable object, typically an ``ICRCurve`` instance or compatible
        function, that accepts ``params``, ``t``, and ``num_samples`` and
        returns ICR-related quantities.
    deme_names : array of str
        The ordered deme names corresponding to entries in ``sample_config``.

    Returns
    -------
    Scalar
        The total log-likelihood, computed as the sum of
        ``log(result["c"]) + result["log_s"]`` over the given time points.

    Notes
    -----
    This function converts the positional sampling configuration into the
    deme-name mapping expected by ``icr_call``, evaluates the ICR quantities, and
    combines the returned components into a scalar log-likelihood.

    The callable ``icr_call`` is typically an ICR or CCR object and is expected
    to return a dictionary containing the entries ``"c"`` and ``"log_s"``. The
    corresponding mean-field objects may be passed instead; any callable that
    returns ``"c"`` and ``"log_s"`` will work.

    ::

        icr_exact = ICRCurve(demo=g, k=2)
        ll = icr_loglik(
            time=jnp.array([10.0, 100.0, 1000.0]),
            sample_config=[2, 0],
            params={},
            icr_call=icr_exact,
            deme_names=["P0", "P1"],
        )

    See Also
    --------
    demestats.iicr.IICRCurve
    demestats.iicr.IICRCurve.__call__
    demestats.iicr.IICRMeanFieldCurve
    demestats.iicr.IICRMeanFieldCurve.__call__
    demestats.iicr.CCRCurve
    demestats.iicr.CCRCurve.__call__
    demestats.iicr.CCRMeanFieldCurve
    demestats.iicr.CCRMeanFieldCurve.__call__
    """
    # Map each deme name to its sample count, as expected by ``icr_call``.
    ns = {name: sample_config[i] for i, name in enumerate(deme_names)}
    result = icr_call(params=params, t=time, num_samples=ns)
    # jax.debug.print("result: {}", result)
    return jnp.sum(jnp.log(result["c"]) + result["log_s"])