sampling

riemax.numerical.sampling

riemax.numerical.sampling.rwmh_sampler(key: jax.random.PRNGKeyArray, n_samples: int, fn: tp.Callable[[jax.Array], float], initial_position: jax.Array, burnin_steps: int = 20000) -> tuple[jax.Array, jax.Array]

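Conduct random-walk Metropolis-Hastings (RWMH) sampling. The chain first runs burnin_steps steps that are discarded, then n_samples steps that are recorded.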

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | jax.random.PRNGKeyArray | random key used for the sampling procedure | required |
| n_samples | int | number of Metropolis-Hastings steps to record after burn-in | required |
| fn | typing.Callable[[jax.Array], float] | function returning the (unnormalised) probability density at a position, used to decide whether a proposed sample is accepted | required |
| initial_position | jax.Array | position at which the random walk starts | required |
| burnin_steps | int | number of initial steps of the MCMC chain to discard | 20000 |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| pos | jax.Array | positions of the sampled points, one per recorded step |
| do_accept | jax.Array | whether each sampled point came from an accepted proposal |
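
To make the roles of `fn`, `burnin_steps` and the two return values concrete, here is a minimal standard-library sketch of random-walk Metropolis-Hastings on a scalar position. It is not riemax's implementation: `rwmh_sketch`, the Gaussian proposal, `step_size` and the shortened default burn-in are all assumptions made for illustration.

```python
import math
import random


def rwmh_sketch(seed, n_samples, fn, initial_position, burnin_steps=1_000, step_size=0.5):
    """Scalar random-walk Metropolis-Hastings, for illustration only."""
    rng = random.Random(seed)
    pos = initial_position
    log_prob = math.log(fn(pos))

    positions, accepted = [], []
    for step in range(burnin_steps + n_samples):
        # Symmetric Gaussian proposal centred on the current position
        # (step_size is a hypothetical tuning parameter).
        proposal = pos + rng.gauss(0.0, step_size)
        density = fn(proposal)
        proposal_log_prob = math.log(density) if density > 0.0 else -math.inf

        # Accept with probability min(1, fn(proposal) / fn(pos)).
        do_accept = math.log(1.0 - rng.random()) < proposal_log_prob - log_prob
        if do_accept:
            pos, log_prob = proposal, proposal_log_prob

        # Steps taken during burn-in are discarded.
        if step >= burnin_steps:
            positions.append(pos)
            accepted.append(do_accept)

    return positions, accepted


def unnormalised_normal(x):
    # Standard normal density without its normalising constant.
    return math.exp(-0.5 * x * x)


positions, accepted = rwmh_sketch(0, 5_000, unnormalised_normal, 0.0)
acceptance_rate = sum(accepted) / len(accepted)
```

Only ratios of `fn` enter the acceptance test, so the target density does not need to be normalised. The acceptance rate computed from the second return value is a common diagnostic when tuning the proposal scale.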

Source code in src/riemax/numerical/sampling.py
def rwmh_sampler(
    key: jax.random.PRNGKeyArray,
    n_samples: int,
    fn: tp.Callable[[jax.Array], float],
    initial_position: jax.Array,
    burnin_steps: int = 20_000,
) -> tuple[jax.Array, jax.Array]:
    """Conduct Random-Walk Metropolis Hastings sampling.

    Parameters:
        key: random key used for the sampling procedure
        n_samples: number of Metropolis-Hastings steps to record after burn-in
        fn: function returning the (unnormalised) probability density at a position
        initial_position: position at which the random walk starts
        burnin_steps: number of initial steps of the MCMC chain to discard

    Returns:
        pos: positions of the sampled points, one per recorded step
        do_accept: whether each sampled point came from an accepted proposal
    """

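    def mh_update(state, _):
        key, pos, log_prob, do_accept = state

        # Carry one half of the split forward and spend the other on this step,
        # so that no key is both consumed and split again.
        key, subkey = jax.random.split(key)
        new_position, new_log_prob, do_accept = _rwmh_kernel(subkey, fn, pos, log_prob)

        # Emit the incoming state so the scan records the chain before this update.
        return (key, new_position, new_log_prob, do_accept), state

    # The starting point is marked as accepted.
    initial_state = (key, initial_position, jnp.log(fn(initial_position)), True)

    # Burn-in: run the chain and keep only the final state.
    burnin_state, _ = jax.lax.scan(mh_update, initial_state, None, burnin_steps)

    # Sampling: continue from the burn-in state and collect every step.
    _, (_, pos, _, do_accept) = jax.lax.scan(mh_update, burnin_state, None, n_samples)

    return pos, do_accept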
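
The listing delegates each proposal to a private `_rwmh_kernel` whose body is not shown on this page. The following is a hedged sketch of what a kernel with the same call shape might look like, driven by `jax.lax.scan`. The names `rwmh_kernel_sketch`, `step` and `step_size`, and the choice of a Gaussian proposal, are assumptions and may differ from the library's code.

```python
import jax
import jax.numpy as jnp


def rwmh_kernel_sketch(key, fn, position, log_prob, step_size=0.5):
    """Hypothetical stand-in for the private _rwmh_kernel; not the library's code."""
    proposal_key, accept_key = jax.random.split(key)

    # Symmetric Gaussian random-walk proposal around the current position.
    proposal = position + step_size * jax.random.normal(proposal_key, position.shape)
    proposal_log_prob = jnp.log(fn(proposal))

    # Accept with probability min(1, fn(proposal) / fn(position)).
    log_uniform = jnp.log(jax.random.uniform(accept_key))
    do_accept = log_uniform < proposal_log_prob - log_prob

    # Keep the current state when the proposal is rejected.
    new_position = jnp.where(do_accept, proposal, position)
    new_log_prob = jnp.where(do_accept, proposal_log_prob, log_prob)
    return new_position, new_log_prob, do_accept


def unnormalised_normal(x):
    # Isotropic Gaussian density without its normalising constant.
    return jnp.exp(-0.5 * jnp.sum(x**2))


def step(state, _):
    key, position, log_prob = state
    # Keep one half of the split for later steps and spend the other half here.
    key, subkey = jax.random.split(key)
    position, log_prob, do_accept = rwmh_kernel_sketch(subkey, unnormalised_normal, position, log_prob)
    return (key, position, log_prob), (position, do_accept)


initial_position = jnp.zeros(2)
initial_state = (
    jax.random.PRNGKey(0),
    initial_position,
    jnp.log(unnormalised_normal(initial_position)),
)
_, (positions, do_accept) = jax.lax.scan(step, initial_state, None, length=1_000)
acceptance_rate = jnp.mean(do_accept)
```

Using `jnp.where` rather than a Python `if` keeps the accept/reject branch traceable, which is what allows the whole chain to run inside a single `jax.lax.scan`.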