jax_backend Module
This module compiles pangolin models into plain-old JAX functions.
- pangolin.jax_backend.ancestor_sample(vars, key=None, size=None)
Draw exact (ancestral) samples of the RVs in vars.
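- Parameters:
vars – pytree containing the RVs to sample.
key – JAX PRNGKey used for sampling (optional; defaults to None).
size – number of samples to draw, or None for a single sample (optional).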
- Returns:
out – Pytree matching the structure of vars, but with jax.ndarray arrays in place of RV. If size is None, each array has the same shape as the corresponding RV. Otherwise, each array has an extra leading dimension of length size.
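The size convention can be pictured with a small plain-JAX sketch. As an illustration only, it assumes that drawing size samples is equivalent to splitting the key into size independent keys and vmapping a single-sample function. sample_once and sample_with_size are hypothetical names, not pangolin functions, and ancestor_sample is not necessarily implemented this way.

```python
import jax

def sample_once(key):
    # Hypothetical single-sample function: returns a pytree of arrays whose
    # structure and shapes mirror some pytree of RVs (a scalar and a 2-vector).
    k1, k2 = jax.random.split(key)
    return {"a": jax.random.normal(k1), "b": [jax.random.normal(k2, (2,))]}

def sample_with_size(key, size=None):
    # Hypothetical wrapper illustrating the documented shape convention.
    if size is None:
        # One draw: every leaf keeps the shape of its RV.
        return sample_once(key)
    # `size` draws: one independent key per draw, vmapped, which adds a
    # leading axis of length `size` to every leaf of the output pytree.
    keys = jax.random.split(key, size)
    return jax.vmap(sample_once)(keys)

key = jax.random.PRNGKey(0)
one = sample_with_size(key)
many = sample_with_size(key, size=5)
print(one["a"].shape, one["b"][0].shape)    # () (2,)
print(many["a"].shape, many["b"][0].shape)  # (5,) (5, 2)
```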
Examples
Sample a constant RV.
>>> x = RV(ir.Constant(1.5))
>>> ancestor_sample(x)
Array(1.5, dtype=...)
Sample a PyTree with the RV inside it.
>>> ancestor_sample({'sup': [[x]]})
{'sup': [[Array(1.5, dtype=...)]]}
Draw several samples.
>>> ancestor_sample(x, size=3)
Array([1.5, 1.5, 1.5], dtype=...)
Draw several samples from a PyTree with an RV inside it.
>>> ancestor_sample({'sup': x}, size=3)
{'sup': Array([1.5, 1.5, 1.5], dtype=...)}
Sample from several random variables at once.
>>> y = RV(ir.Add(), x, x)
>>> z = RV(ir.Mul(), x, y)
>>> print(ancestor_sample({'cat': x, 'dog': [y, z]}))
{'cat': Array(1.5, dtype=...), 'dog': [Array(3., dtype=...), Array(4.5, dtype=...)]}
- pangolin.jax_backend.ancestor_sampler(vars)
Compiles a pytree of RVs into a plain-old JAX function that takes a PRNGKey and returns a pytree with the same structure, containing a joint sample from the distribution of those RVs.
- Parameters:
vars – pytree containing the RVs to sample.
- Returns:
out – function mapping a JAX PRNGKey to a sample in the form of a pytree of jax.ndarray matching the structure and shape of vars.
- Return type:
Callable[[JaxArray], Any]
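For intuition about what such a compiled function does, the sketch below hand-writes an ancestral sampler for a small hypothetical model: z ~ Normal(0, 1), w ~ Normal(z, 1), s = z + w. The model and the name hand_written_sampler are invented for illustration, and this is not the code pangolin generates. Nodes are visited in topological order, so every parent is drawn before its children, and that is what makes the result an exact joint sample.

```python
import jax

def hand_written_sampler(key):
    # Hypothetical model: z ~ Normal(0, 1), w ~ Normal(z, 1), s = z + w.
    # Nodes are visited in topological order: parents before children.
    k_z, k_w = jax.random.split(key)
    z = 0.0 + 1.0 * jax.random.normal(k_z)  # root: loc 0, scale 1
    w = z + 1.0 * jax.random.normal(k_w)    # child of z: loc z, scale 1
    s = z + w                               # deterministic node
    # Return values in the pytree structure of the requested RVs.
    return [{"z": z}, (w, s)]

key = jax.random.PRNGKey(0)
sample = hand_written_sampler(key)
# A pure function of the key, so standard transformations such as vmap apply.
batch = jax.vmap(hand_written_sampler)(jax.random.split(key, 4))
```

The same key always yields the same sample, and vmapping over split keys adds a leading batch axis to every leaf.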
Examples
>>> x = RV(ir.Constant(1.5))
>>> y = RV(ir.Add(), x, x)
>>> fun = ancestor_sampler([{'cat': x}, y])
You now have a plain-old JAX function that’s completely independent of pangolin.
>>> key = jax.random.PRNGKey(0)
>>> fun(key)
[{'cat': Array(1.5, dtype=float32)}, Array(3., dtype=float32)]
Because it is an ordinary JAX function, you can apply standard JAX transformations to it, e.g., vmap it.
>>> print(jax.vmap(fun)(jax.random.split(key, 3)))
[{'cat': Array([1.5, 1.5, 1.5], dtype=float32)}, Array([3., 3., 3.], dtype=float32)]
- pangolin.jax_backend.ancestor_log_prob(*vars, **kwvars)
Given RVs or pytrees of RVs, passed positionally (vars) and/or by keyword (kwvars), create a plain-old JAX function that computes their joint log probability.
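- Parameters:
vars – RVs or pytrees of RVs, passed positionally.
kwvars – RVs or pytrees of RVs, passed by keyword.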
- Returns:
out – log-prob function that takes jax.ndarray arguments matching vars and kwvars and returns a scalar.
- Return type:
Callable
Examples
>>> loc = ir.RV(ir.Constant(0.0))
>>> scale = ir.RV(ir.Constant(1.0))
>>> x = RV(ir.Normal(), loc, scale)
>>> fun = ancestor_log_prob(x)
You now have a plain JAX function that’s completely independent of pangolin. You can evaluate it.
>>> fun(0.0)
Array(-0.9189385, dtype=...)
Or you can vmap it.
>>> jax.vmap(fun)(jnp.array([0.0, 0.5]))
Array([-0.9189385, -1.0439385], dtype=float32)
Here’s a more complex example:
>>> op = ir.VMap(ir.Normal(), [None, None], 3)
>>> y = RV(op, loc, scale)
>>> fun = ancestor_log_prob({'x': x, 'y': y})
>>> fun({'x': 0.0, 'y': [0.0, 0.5, 0.1]})
Array(-3.8057542, dtype=...)
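The value -3.8057542 above is a joint log density: log p(x) plus the three log p(y_i) terms. The sketch below spells out that factorization in plain JAX for a slightly different, hypothetical model in which the vmapped Normal's loc is itself random (w_i ~ Normal(z, 1)), so each term is conditioned on its parent. joint_log_prob is an illustrative name, not pangolin's output. With z = 0, every w_i has the same Normal(0, 1) density as the corresponding y_i above, so lp also comes out to about -3.80575.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def joint_log_prob(values):
    # Hypothetical model (not the one above): z ~ Normal(0, 1) and
    # w_i ~ Normal(z, 1) for i = 1, 2, 3, i.e. a vmapped Normal with loc z.
    z = values["z"]
    w = jnp.asarray(values["w"])
    lp_z = norm.logpdf(z, 0.0, 1.0)         # log p(z)
    lp_w = jnp.sum(norm.logpdf(w, z, 1.0))  # sum over i of log p(w_i | z)
    return lp_z + lp_w                      # scalar joint log density

vals = {"z": 0.0, "w": jnp.array([0.0, 0.5, 0.1])}
lp = joint_log_prob(vals)
# A plain JAX function can also be differentiated with respect to every leaf.
grads = jax.grad(joint_log_prob)(vals)
```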
You can also create a function that uses positional and/or keyword arguments:
>>> fun = ancestor_log_prob(x, cat=y)
>>> fun(0.0, cat=[0.0, 0.5, 0.1])
Array(-3.8057542, dtype=...)
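In these examples, x and each element of y are independent Normal(0, 1) values, because their loc and scale are the constants 0.0 and 1.0. Their joint log density is therefore a plain sum over every value passed in, however those values are split between positional and keyword arguments. The sketch below shows this in plain JAX using jax.tree_util.tree_leaves. iid_std_normal_log_prob is a hypothetical name, and the sketch only holds for this i.i.d. special case. It is not a general implementation of ancestor_log_prob.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def iid_std_normal_log_prob(*args, **kwargs):
    # Hypothetical: every leaf is treated as an independent Normal(0, 1) value,
    # which holds for x and y above. Positional and keyword arguments are packed
    # into one pytree and flattened, so either calling style gives the same sum.
    # It accepts any structure; it does not check the arguments against RVs.
    leaves = jax.tree_util.tree_leaves((args, kwargs))
    return sum(jnp.sum(norm.logpdf(jnp.asarray(leaf))) for leaf in leaves)

a = iid_std_normal_log_prob(0.0, cat=[0.0, 0.5, 0.1])
b = iid_std_normal_log_prob({"x": 0.0, "y": [0.0, 0.5, 0.1]})
```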