Source code for demestats.loglik.sfs_loglik

import jax.numpy as jnp
import numpy as np  # This can be deleted if I change rng in prepare_projection
from jax.scipy.special import xlogy
from demestats.fit.util import fold_sfs


def sfs_loglik(afs, esfs, sequence_length=None, theta=None, folded=False):
    """
    Evaluate the multinomial or Poisson log-likelihood of an observed site
    frequency spectrum (AFS) given an expected spectrum (ESFS).

    By default, the sequence length and mutation rate (theta) are None, and the
    multinomial likelihood is used. To use the Poisson likelihood, provide BOTH
    the sequence length and the mutation rate (theta).

    Parameters
    ----------
    afs : array_like
        Observed allele frequency spectrum.
    esfs : array_like
        Expected allele frequency spectrum. Must be the same shape as ``afs``.
    sequence_length : int, optional
        Total number of sites in the sequence. Required if ``theta`` is given.
    theta : float, optional
        Population-scaled mutation rate. If provided, a sequence length must
        also be provided and the Poisson likelihood is used; otherwise a
        multinomial likelihood is assumed.
    folded : bool, optional
        Whether to work with the folded SFS. If True, the expected spectrum
        is folded with ``fold_sfs`` before the likelihood is evaluated.

    Returns
    -------
    float
        Log-likelihood of the observed spectrum given the expected spectrum,
        up to an additive constant that does not depend on the expected
        spectrum.

    Notes
    -----
    In tskit, given a tree sequence, the afs can be obtained with::

        afs = tree_sequence.allele_frequency_spectrum(*options)

    To obtain the esfs with ``momi3``, first initialize an ExpectedSFS object
    with a ``demes`` demographic model and a dictionary of the number of
    samples used per population. Then pass a dictionary of parameter values
    to the ExpectedSFS object::

        ESFS_obj = demestats.sfs.ExpectedSFS(demes_model.to_demes(), num_samples=samples_per_population)
        params = {param_key: value}
        esfs = ESFS_obj(params)
        multinomial_loglik_value = sfs_loglik(afs, esfs)
        poisson_loglik_value = sfs_loglik(afs, esfs, sequence_length=1e8, theta=1e-8)

    To compute the gradient, use ``jax.grad`` or ``jax.value_and_grad``.
    All log-likelihood functions are compatible with ``jax``.

    Please refer to the tutorial for a specific example; the code above only
    outlines how to call the functions.

    See Also
    --------
    demestats.sfs.ExpectedSFS
    """
    # Drop the first and last entries of the flattened spectra.
    afs = afs.flatten()[1:-1]
    esfs = esfs.flatten()[1:-1]

    if folded:
        esfs = fold_sfs(esfs)

    if theta:
        # Poisson likelihood: each entry has rate esfs * sequence_length * theta.
        assert sequence_length
        tmp = esfs * sequence_length * theta
        return jnp.sum(-tmp + xlogy(afs, tmp))
    else:
        # Multinomial likelihood: normalize the expected spectrum to probabilities.
        return jnp.sum(xlogy(afs, esfs / esfs.sum()))
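The two branches of ``sfs_loglik`` can be checked by hand on a toy spectrum. The following is a minimal sketch that uses NumPy in place of JAX so it runs without ``demestats`` installed; ``toy_sfs_loglik`` and the local ``xlogy`` helper are hypothetical stand-ins written for this illustration, not part of the package, and folding is left out.

```python
import numpy as np


def xlogy(x, y):
    """x * log(y), taken to be 0 wherever x == 0 (stand-in for jax.scipy.special.xlogy)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    safe_y = np.where(x == 0, 1.0, y)  # avoid log(0) where the term is dropped anyway
    return x * np.log(safe_y)


def toy_sfs_loglik(afs, esfs, sequence_length=None, theta=None):
    """Hypothetical NumPy mirror of the arithmetic in sfs_loglik (no folding)."""
    afs = np.asarray(afs, dtype=float).flatten()[1:-1]
    esfs = np.asarray(esfs, dtype=float).flatten()[1:-1]
    if theta is not None:
        if sequence_length is None:
            raise ValueError("sequence_length is required when theta is given")
        rate = esfs * sequence_length * theta  # Poisson rate per entry
        return float(np.sum(-rate + xlogy(afs, rate)))
    # Multinomial: only the shape of the expected spectrum matters.
    return float(np.sum(xlogy(afs, esfs / esfs.sum())))


# Toy one-population spectra; the first and last entries are dropped.
afs = np.array([0, 6, 3, 1, 0])
esfs = np.array([0.0, 3.0, 1.5, 0.5, 0.0])

# Normalized probabilities are [0.6, 0.3, 0.1].
multinomial = toy_sfs_loglik(afs, esfs)

# With sequence_length * theta = 2 the Poisson rates are [6, 3, 1].
poisson = toy_sfs_loglik(afs, esfs, sequence_length=2.0, theta=1.0)

print(multinomial, poisson)
```

Because the multinomial branch normalizes the expected spectrum, rescaling ``esfs`` by any positive constant leaves that value unchanged, whereas the Poisson value depends on the absolute scale through ``sequence_length * theta``.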


def prepare_projection(afs, afs_samples, sequence_length, num_projections, seed):
    """
    Create the specified number of random projection vectors and the
    inputs for the Einstein summation for tensor operations that are used in
    ``ExpectedSFS.tensor_prod``.

    Parameters
    ----------
    afs : array_like
        Observed allele frequency spectrum.
    afs_samples : dict
        Dictionary of the number of samples per population. Its keys are used
        as population names and must be in the same order as the axes of
        ``afs``.
    sequence_length : int or None
        Total number of sites in the sequence, or None for the multinomial
        likelihood. Both cases currently draw the same kind of projection
        vectors.
    num_projections : int
        Number of desired random projection vectors.
    seed : int
        Seed for reproducibility.

    Returns
    -------
    dict
        Dictionary of random projection vectors, keyed by population name.
    str
        String containing axis names separated by commas for the Einstein
        summation, e.g. ``"za,zb,ab->z"`` for two populations.
    list
        List of the projection arrays, one per axis of ``afs``, followed by
        ``afs`` itself. This list is a necessary input for the Einstein
        summation.

    Notes
    -----
    ``proj_dict`` contains the random projection vectors that define the
    low-dimensional subspace for approximating the full expected SFS,
    ``einsum_str`` is a string specifying the Einstein summation for tensor
    operations, and ``input_arrays`` are preprocessed arrays that serve as
    inputs to the ``jax.numpy.einsum`` call, optimized for JAX's just-in-time
    compilation.

    Example::

        proj_dict, einsum_str, input_arrays = prepare_projection(afs, afs_samples, sequence_length, num_projections, seed)

    Please refer to the ``Random Projection`` section for a specific example;
    the code above only outlines how to call the function.

    See Also
    --------
    demestats.loglik.sfs_loglik.projection_sfs_loglik
    """
    rng = np.random.default_rng(seed)
    proj_dict = {}
    pop_names = list(afs_samples.keys())
    n_dims = afs.ndim

    # One (num_projections, axis_length) array of 0/1 entries per population.
    for i in range(n_dims):
        if sequence_length is None:  # Multinomial
            proj_dict[pop_names[i]] = rng.integers(
                0, 2, size=(num_projections, afs.shape[i]), dtype=jnp.int32
            )
            # uniform0 = rng.uniform(0, 1, size=(num_projections, afs.shape[i]))
            # proj_dict[pop_names[i]] = (uniform0 / uniform0.sum(axis=1, keepdims=True)).astype(jnp.float32)
        else:
            proj_dict[pop_names[i]] = rng.integers(
                0, 2, size=(num_projections, afs.shape[i]), dtype=jnp.int32
            )
            # Ask JT if it's fine to leave it like this, don't fix it if it didn't break? :)

    input_subscripts = ",".join(
        [f"z{chr(97 + i)}" for i in range(n_dims)]
    )  # "za,zb,zc"
    tensor_subscript = "".join([chr(97 + i) for i in range(n_dims)])  # "abc"
    output_subscript = "z"  # "z"
    einsum_str = f"{input_subscripts},{tensor_subscript}->{output_subscript}"

    input_arrays = [proj_dict[pop_names[i]] for i in range(n_dims)] + [afs]

    return proj_dict, einsum_str, input_arrays
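To make the Einstein summation string concrete, the sketch below rebuilds it for a toy two-population spectrum and checks the contraction against an explicit per-projection product. It uses NumPy only; ``build_einsum_str`` is a hypothetical helper that copies the string-building lines of ``prepare_projection``, and the toy ``afs`` is random illustrative data.

```python
import numpy as np


def build_einsum_str(n_dims):
    """Hypothetical helper: same subscript construction as in prepare_projection."""
    input_subscripts = ",".join(f"z{chr(97 + i)}" for i in range(n_dims))
    tensor_subscript = "".join(chr(97 + i) for i in range(n_dims))
    return f"{input_subscripts},{tensor_subscript}->z"


rng = np.random.default_rng(0)
afs = rng.integers(0, 10, size=(3, 4))  # toy two-population spectrum
num_projections = 5

# One binary (num_projections, axis_length) array per axis of afs.
vectors = [rng.integers(0, 2, size=(num_projections, n)) for n in afs.shape]

einsum_str = build_einsum_str(afs.ndim)  # "za,zb,ab->z"
projected = np.einsum(einsum_str, *vectors, afs)

# Entry z of the result is u_z^T @ afs @ v_z, i.e. the sum of the afs entries
# selected by the z-th pair of 0/1 vectors.
manual = np.array(
    [vectors[0][z] @ afs @ vectors[1][z] for z in range(num_projections)]
)

print(einsum_str, projected.shape)
```

The shared ``z`` index means the vectors are paired across populations rather than combined in every possible way, so the output has one entry per projection regardless of how many populations the spectrum has.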


def projection_sfs_loglik(
    esfs_obj,
    params,
    proj_dict,
    einsum_str,
    input_arrays,
    sequence_length=None,
    theta=None,
    folded=False,
):
    """
    Evaluate the **projected** multinomial or Poisson log-likelihood of an
    observed site frequency spectrum (AFS) given an expected spectrum (ESFS)
    via Einstein summation.

    By default, the sequence length and mutation rate (theta) are None, and the
    multinomial likelihood is used. To use the Poisson likelihood, provide BOTH
    the sequence length and the mutation rate (theta).

    Parameters
    ----------
    esfs_obj : demestats.sfs.ExpectedSFS
        A ``demestats.sfs.ExpectedSFS`` object.
    params : dict
        Dictionary of model parameters and their values.
    proj_dict : dict
        Dictionary of arrays that represent projection vectors.
    einsum_str : str
        Einstein summation string for the projection.
    input_arrays : list
        Input arrays for the einsum operation. It must contain the original
        afs.
    sequence_length : int, optional
        Total number of sites in the sequence. Required if ``theta`` is given.
    theta : float, optional
        Population-scaled mutation rate. If provided, a sequence length must
        also be provided and the Poisson likelihood is used; otherwise a
        multinomial likelihood is assumed.
    folded : bool, optional
        Currently unused by this function.

    Returns
    -------
    float
        Log-likelihood of the projected observed spectrum given the projected
        expected spectrum.

    Notes
    -----
    ``proj_dict`` contains the random projection vectors that define the
    low-dimensional subspace for approximating the full expected SFS,
    ``einsum_str`` is a string specifying the Einstein summation for tensor
    operations, and ``input_arrays`` are preprocessed arrays that serve as
    inputs to the ``jax.numpy.einsum`` call, optimized for JAX's just-in-time
    compilation.

    Example::

        proj_dict, einsum_str, input_arrays = prepare_projection(afs, afs_samples, sequence_length, num_projections, seed)
        esfs_obj = ExpectedSFS(demo.to_demes(), num_samples=afs_samples)
        params = {param_key: val}
        projection_sfs_loglik(esfs_obj, params, proj_dict, einsum_str, input_arrays, sequence_length=None, theta=None)

    Internally this function calls ``demestats.sfs.ExpectedSFS.tensor_prod``,
    which performs the projection operations on the site frequency spectrum.

    Please refer to the tutorial for a specific example; the code above only
    outlines how to call the functions.

    See Also
    --------
    demestats.sfs.ExpectedSFS
    demestats.sfs.ExpectedSFS.tensor_prod
    demestats.loglik.sfs_loglik.prepare_projection
    """
    # Projected expected spectrum, one value per projection vector.
    result1 = esfs_obj.tensor_prod(proj_dict, params)
    # Projected observed spectrum, using the same projection vectors.
    result2 = jnp.einsum(einsum_str, *input_arrays)

    if theta:
        # Same requirement as in sfs_loglik: the Poisson branch needs both values.
        assert sequence_length
        tmp = result1 * sequence_length * theta
        return jnp.sum(-tmp + xlogy(result2, tmp))
    else:
        return jnp.sum(xlogy(result2, result1 / jnp.sum(result1)))
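The projected likelihood can be traced by hand on a small example. The sketch below uses NumPy only and fixed, hand-picked projection vectors. It assumes that ``tensor_prod`` returns the expected spectrum contracted against the same projection vectors, which the docstring describes but this listing does not show, so an explicit einsum over a toy ``esfs`` array stands in for that call; all arrays are illustrative.

```python
import numpy as np

# Toy two-population observed and expected spectra (illustrative values).
afs = np.array([[0.0, 4.0, 2.0],
                [3.0, 1.0, 0.0]])
esfs = np.array([[0.0, 0.4, 0.2],
                 [0.3, 0.1, 0.0]])

# Two hand-picked binary projections, one row per projection.
u = np.array([[1, 0], [1, 1]])        # shape (num_projections, afs.shape[0])
v = np.array([[0, 1, 1], [1, 1, 0]])  # shape (num_projections, afs.shape[1])
einsum_str = "za,zb,ab->z"

# Projected observed spectrum: sums of the selected entries, here [6, 8].
observed_proj = np.einsum(einsum_str, u, v, afs)

# ASSUMPTION: stand-in for esfs_obj.tensor_prod(proj_dict, params); here [0.6, 0.8].
expected_proj = np.einsum(einsum_str, u, v, esfs)

# Multinomial branch: normalize the projected expectation, then sum x * log(p).
# All projected counts are positive here, so plain log is safe.
probs = expected_proj / expected_proj.sum()
multinomial = float(np.sum(observed_proj * np.log(probs)))

# Poisson branch: rates are the projected expectation times sequence_length * theta.
sequence_length, theta = 10.0, 1.0
rate = expected_proj * sequence_length * theta
poisson = float(np.sum(-rate + observed_proj * np.log(rate)))

print(multinomial, poisson)
```

Each projection compresses the full spectrum into a single pair of numbers, so the likelihood is evaluated over ``num_projections`` terms instead of one term per spectrum entry.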