jax_backend
This module compiles pangolin models into plain-old JAX functions.
These methods accept an optional biject argument. If True, constrained continuous distributions are automatically transformed to an unconstrained space. If False, no transformation is applied. You can also provide a dictionary to control precisely how distributions are transformed. See the bijectors submodule for more information.
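To illustrate what transforming to an unconstrained space means, here is a minimal sketch in plain JAX. It does not use pangolin's bijectors submodule; the exponential distribution, the log/exp transform, and the function names are assumptions chosen for illustration. A positive variable x is rewritten as u = log(x), and the log density gains the log-Jacobian term log|dx/du| = u.

```python
import jax
import jax.numpy as jnp

def exponential_log_prob(x, rate=1.0):
    # Log density of Exponential(rate); only meaningful for x > 0.
    return jnp.log(rate) - rate * x

def unconstrained_log_prob(u, rate=1.0):
    # Hypothetical transformed density: u = log(x) may be any real number.
    x = jnp.exp(u)
    # Change of variables: add log|dx/du| = log(exp(u)) = u.
    return exponential_log_prob(x, rate) + u

u = jnp.array(-2.0)                           # any real value is valid here
lp = unconstrained_log_prob(u)
grad = jax.grad(unconstrained_log_prob)(u)    # gradient with respect to u
```

Working in the unconstrained variable means gradient-based methods never propose an invalid (negative) value of x.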
- pangolin.jax_backend.ancestor_sample(vars, key=None, size=None, biject=False)
Draw exact samples.
- Parameters:
key (Optional[JaxArray]) – a JAX PRNGKey or None (default)
size (Optional[int]) – number of samples to draw (default of None is just a single sample)
biject (bool | dict) – A boolean determining if constrained continuous distributions should be transformed to an unconstrained space, or a dict giving a non-default set of bijectors. (Default: False)
- Returns:
out – Pytree matching structure of vars, but with jax.ndarray arrays in place of RV. If size is None, then each array will have the same shape as the corresponding RV. Otherwise, each array will have an extra dimension of size size prepended as the first axis.
Examples
Sample a constant RV.
>>> x = RV(ir.Constant(1.5))
>>> ancestor_sample(x)
Array(1.5, dtype=...)
Sample a PyTree with the RV inside it.
>>> ancestor_sample({'sup': [[x]]})
{'sup': [[Array(1.5, dtype=...)]]}
Draw several samples.
>>> ancestor_sample(x, size=3)
Array([1.5, 1.5, 1.5], dtype=...)
Draw several samples from a PyTree with an RV inside it.
>>> ancestor_sample({'sup': x}, size=3)
{'sup': Array([1.5, 1.5, 1.5], dtype=...)}
Sample from several random variables at once.
>>> y = RV(ir.Add(), x, x)
>>> z = RV(ir.Mul(), x, y)
>>> print(ancestor_sample({'cat': x, 'dog': [y, z]}))
{'cat': Array(1.5, dtype=...), 'dog': [Array(3., dtype=...), Array(4.5, dtype=...)]}
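The shape convention for size can be mimicked in plain JAX. The sketch below is not pangolin's implementation: sample_once is a hypothetical stand-in for a single joint draw, and sample_many assumes that multiple samples are produced by splitting the key and mapping over the resulting keys.

```python
import jax
import jax.numpy as jnp

def sample_once(key):
    # Hypothetical stand-in for one joint sample: a pytree of arrays.
    k1, k2 = jax.random.split(key)
    return {'a': jax.random.normal(k1), 'b': jax.random.normal(k2, (4,))}

def sample_many(key, size=None):
    if size is None:
        return sample_once(key)            # shapes: () and (4,)
    keys = jax.random.split(key, size)     # one independent key per sample
    return jax.vmap(sample_once)(keys)     # shapes: (size,) and (size, 4)

key = jax.random.PRNGKey(0)
single = sample_many(key)
batch = sample_many(key, size=3)
```

With size=None the arrays keep their natural shapes; with size=3 every leaf of the pytree gains a leading axis of length 3.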
- pangolin.jax_backend.ancestor_sampler(vars, biject=False)
Compiles a pytree of RVs into a plain-old JAX function that takes a PRNGKey and returns a pytree with the same structure, containing a joint sample from the distribution of those RVs.
- Parameters:
biject – A boolean determining if constrained continuous distributions should be transformed to an unconstrained space, or a dict giving a non-default set of bijectors. (Default: False)
- Returns:
out – function mapping a JAX PRNGKey to a sample in the form of a pytree of jax.ndarray matching the structure and shape of vars
- Return type:
Callable[[JaxArray], Any]
Examples
>>> x = RV(ir.Constant(1.5))
>>> y = RV(ir.Add(), x, x)
>>> fun = ancestor_sampler([{'cat': x}, y])
You now have a plain-old JAX function that’s completely independent of pangolin.
>>> key = jax.random.PRNGKey(0)
>>> fun(key)
[{'cat': Array(1.5, dtype=float32)}, Array(3., dtype=...)]
You can do normal JAX stuff with it, e.g. vmap it.
>>> print(jax.vmap(fun)(jax.random.split(key, 3)))
[{'cat': Array([1.5, 1.5, 1.5], dtype=float32)}, Array([3., 3., 3.], dtype=...)]
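Because the result is an ordinary function of a key, other JAX transformations such as jax.jit should apply as well. The sketch below uses a hand-written stand-in rather than a function generated by pangolin; hand_written_sampler and its two-variable model are assumptions for illustration. It also shows the idea behind ancestral sampling: each variable is drawn after its parents.

```python
import jax
import jax.numpy as jnp

def hand_written_sampler(key):
    # Ancestral sampling: draw each variable after its parents.
    k_x, k_y = jax.random.split(key)
    x = jax.random.normal(k_x)               # x ~ Normal(0, 1)
    y = x + 0.1 * jax.random.normal(k_y)     # y | x ~ Normal(x, 0.1)
    return [{'cat': x}, y]

fast = jax.jit(hand_written_sampler)         # compile once, reuse afterwards
key = jax.random.PRNGKey(0)
out = fast(key)
same = fast(key)                             # same key gives the same draw
```

Since all randomness enters through the key, calling the function twice with the same key reproduces the same sample.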
- pangolin.jax_backend.ancestor_log_prob(*vars, biject=False, **kwvars)
Given one or more pytrees of RVs, create a plain-old JAX function that computes their joint log probability.
- Parameters:
biject (bool | dict) – A boolean determining if constrained continuous distributions should be transformed to an unconstrained space, or a dict giving a non-default set of bijectors. Only available as a keyword argument. (Default: False)
kwvars (PyTree[RV]) – more pytrees of RV, given as keyword arguments
- Returns:
out – log-prob function that expects jax.ndarray arguments matching vars and kwvars and returns a scalar.
- Return type:
Callable
Examples
>>> loc = ir.RV(ir.Constant(0.0))
>>> scale = ir.RV(ir.Constant(1.0))
>>> x = RV(ir.Normal(),loc,scale)
>>> fun = ancestor_log_prob(x)
You now have a plain JAX function that’s completely independent of pangolin. You can evaluate it.
>>> fun(0.0)
Array(-0.9189385, dtype=...)
Or you can vmap it.
>>> jax.vmap(fun)(jnp.array([0.0, 0.5]))
Array([-0.9189385, -1.0439385], dtype=...)
Here’s a more complex example:
>>> op = ir.VMap(ir.Normal(), [None,None], 3)
>>> y = RV(op,loc,scale)
>>> fun = ancestor_log_prob({'x':x, 'y':y})
>>> fun({'x':0.0, 'y':[0.0, 0.5, 0.1]})
Array(-3.8057542, dtype=...)
You can also create a function that uses positional and/or keyword arguments:
>>> fun = ancestor_log_prob(x, cat=y)
>>> fun(0.0, cat=[0.0, 0.5, 0.1])
Array(-3.8057542, dtype=...)
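As a sanity check, the same quantity can be computed by hand in plain JAX. The sketch below does not call pangolin; hand_written_log_prob is a hypothetical function that assumes the model from the examples above (one standard normal x plus three independent standard normal entries). Its value should agree with the result shown above up to floating-point rounding, and, like any JAX function, it can be differentiated.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def hand_written_log_prob(x, cat):
    # x ~ Normal(0, 1), plus three i.i.d. Normal(0, 1) entries in cat.
    lp_x = norm.logpdf(x, loc=0.0, scale=1.0)
    lp_cat = jnp.sum(norm.logpdf(jnp.asarray(cat), loc=0.0, scale=1.0))
    return lp_x + lp_cat

value = hand_written_log_prob(0.0, cat=[0.0, 0.5, 0.1])

# Gradient with respect to the first argument x, evaluated at x = 1.0.
grad_x = jax.grad(hand_written_log_prob)(1.0, jnp.array([0.0, 0.5, 0.1]))
```

For a standard normal, the derivative of the log density with respect to x is -x, so the gradient at x = 1.0 is -1.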