sampling
riemax.numerical.sampling
riemax.numerical.sampling.rwmh_sampler(key: jax.random.PRNGKeyArray, n_samples: int, fn: tp.Callable[[jax.Array], float], initial_position: jax.Array, burnin_steps: int = 20000) -> tuple[jax.Array, jax.Array]
Conduct random-walk Metropolis-Hastings (RWMH) sampling.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`key` | `jax.random.PRNGKeyArray` | random key to use for the sampling procedure | required |
`n_samples` | `int` | number of steps to conduct for the Metropolis-Hastings sampling | required |
`fn` | `typing.Callable[[jax.Array], float]` | function used to determine validity of samples | required |
`initial_position` | `jax.Array` | position at which to commence the random walk | required |
`burnin_steps` | `int` | number of initial steps to discard in the MCMC chain | `20000` |
Returns:
Name | Type | Description |
---|---|---|
`pos` | `jax.Array` | positions of points sampled |
`do_accept` | `jax.Array` | whether each sampled point was accepted |
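
For orientation, here is a minimal sketch of the general random-walk Metropolis-Hastings procedure. It is not the riemax implementation. It is written in plain Python rather than JAX so that it has no dependencies, and it makes several assumptions that the reference above does not confirm: the target is supplied as a log-density, the proposal is an isotropic Gaussian, and the burn-in default is much smaller than the 20000 listed above. The names `rwmh_sketch`, `log_density` and `step_size` are hypothetical and exist only for this illustration.

```python
import math
import random


def rwmh_sketch(rng, n_samples, log_density, initial_position,
                burnin_steps=1000, step_size=0.5):
    """Illustrative random-walk Metropolis-Hastings (not riemax's code)."""
    position = list(initial_position)
    log_p = log_density(position)
    positions, accepted = [], []

    for step in range(burnin_steps + n_samples):
        # Propose a move by adding Gaussian noise to every coordinate.
        proposal = [x + rng.gauss(0.0, step_size) for x in position]
        log_p_new = log_density(proposal)

        # Accept with probability min(1, p(proposal) / p(position)).
        # 1 - rng.random() lies in (0, 1], so the logarithm is always defined.
        do_accept = math.log(1.0 - rng.random()) < log_p_new - log_p
        if do_accept:
            position, log_p = proposal, log_p_new

        # Record the chain state only once the burn-in phase is over.
        if step >= burnin_steps:
            positions.append(list(position))
            accepted.append(do_accept)

    return positions, accepted


def log_standard_normal(x):
    # Unnormalised log-density of a standard normal (assumed example target).
    return -0.5 * sum(v * v for v in x)


rng = random.Random(0)
positions, accepted = rwmh_sketch(rng, 2000, log_standard_normal, [3.0, -3.0])
acceptance_rate = sum(accepted) / len(accepted)
```

The sketch returns one position and one acceptance flag per retained step, mirroring the `pos` and `do_accept` pair described above. When a proposal is rejected, the chain repeats its current position, which is why the acceptance flags are worth inspecting alongside the samples.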