jax_backend Module

This module compiles pangolin models into plain-old JAX functions.

pangolin.jax_backend.ancestor_sample(vars, key=None, size=None)

Draw exact samples from the joint distribution of vars using ancestor sampling.

Parameters:
  • vars (PyTree[RV]) – a PyTree of RVs to sample

  • key (Optional[JaxArray]) – a JAX PRNGKey or None (default)

  • size (Optional[int]) – number of samples to draw (the default, None, draws a single sample)

Returns:

out – PyTree matching the structure of vars, but with a jax.ndarray in place of each RV. If size is None, each array has the same shape as the corresponding RV. Otherwise, each array has an extra leading dimension of length size.
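
The size convention above matches a common JAX pattern: draw one sample per key, then vmap over a batch of split keys, which puts the sample index on the leading axis of every leaf. The sketch below illustrates that pattern in plain JAX. The single-sample function sample_one and the wrapper sample are hypothetical stand-ins for a compiled model, not pangolin's implementation.

```python
import jax
import jax.numpy as jnp


def sample_one(key):
    # Hypothetical stand-in for a compiled single-sample function:
    # returns a pytree whose leaves have fixed shapes () and (2,).
    k1, k2 = jax.random.split(key)
    return {"a": jax.random.normal(k1), "b": jax.random.normal(k2, (2,))}


def sample(key, size=None):
    # size=None: one sample, so every leaf keeps its own shape.
    if size is None:
        return sample_one(key)
    # Otherwise split into `size` independent keys and vmap, which adds a
    # leading axis of length `size` to every leaf of the output pytree.
    keys = jax.random.split(key, size)
    return jax.vmap(sample_one)(keys)


key = jax.random.PRNGKey(0)
single = sample(key)
batch = sample(key, size=3)
print(single["a"].shape, single["b"].shape)  # () (2,)
print(batch["a"].shape, batch["b"].shape)    # (3,) (3, 2)
```

Using one key per sample keeps the draws independent, and vmap handles arbitrary pytree outputs without any manual stacking.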

Examples

Sample a constant RV.

>>> x = RV(ir.Constant(1.5))
>>> ancestor_sample(x)
Array(1.5, dtype=...)

Sample a PyTree with the RV inside it.

>>> ancestor_sample({'sup': [[x]]})
{'sup': [[Array(1.5, dtype=...)]]}

Draw several samples.

>>> ancestor_sample(x, size=3)
Array([1.5, 1.5, 1.5], dtype=...)

Draw several samples from a PyTree containing an RV.

>>> ancestor_sample({'sup': x}, size=3)
{'sup': Array([1.5, 1.5, 1.5], dtype=...)}

Sample several random variables at once.

>>> y = RV(ir.Add(), x, x)
>>> z = RV(ir.Mul(), x, y)
>>> print(ancestor_sample({'cat': x, 'dog': [y, z]}))
{'cat': Array(1.5, dtype=...), 'dog': [Array(3., dtype=...), Array(4.5, dtype=...)]}
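
Ancestor sampling visits variables in topological order, so every parent is sampled before its children, and each child's distribution is evaluated at its parents' sampled values. The example above uses only deterministic operations, so every draw is identical. The following hand-written JAX sketch of a hypothetical two-variable stochastic model (x ~ Normal(0, 1), then y ~ Normal(x, 1); the model and the name sample_chain are illustrative, not pangolin output) shows the same ordering with real randomness.

```python
import jax
import jax.numpy as jnp


def sample_chain(key):
    # One PRNG key per random variable so their noise terms are independent.
    kx, ky = jax.random.split(key)
    # Parent first: x ~ Normal(0, 1).
    x = jax.random.normal(kx)
    # Child second, conditioned on the sampled parent: y ~ Normal(x, 1).
    y = x + jax.random.normal(ky)
    return {"x": x, "y": y}


samples = jax.vmap(sample_chain)(jax.random.split(jax.random.PRNGKey(0), 10_000))
# y is centred on x, so the two are correlated; the exact correlation
# is 1/sqrt(2), about 0.707, and Var(y) = 2.
corr = jnp.corrcoef(samples["x"], samples["y"])[0, 1]
print(corr)
```

Because each joint draw comes straight from the generative process, no burn-in or rejection is needed, which is what makes these samples exact.
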

pangolin.jax_backend.ancestor_sampler(vars)

Compiles a pytree of RVs into a plain-old JAX function that takes a PRNGKey and returns a pytree, with the same structure as vars, containing a joint sample from the distribution of those RVs.

Parameters:

vars (PyTree[RV]) – PyTree of RVs to sample

Returns:

out – function mapping a JAX PRNGKey to a sample in the form of a pytree of jax.ndarray matching the structure and shape of vars

Return type:

Callable[[JaxArray], Any]

Examples

>>> x = RV(ir.Constant(1.5))
>>> y = RV(ir.Add(), x, x)
>>> fun = ancestor_sampler([{'cat': x}, y])

You now have a plain-old JAX function that’s completely independent of pangolin.

>>> key = jax.random.PRNGKey(0)
>>> fun(key)
[{'cat': Array(1.5, dtype=float32)}, Array(3., dtype=float32)]

You can use it with standard JAX transformations, e.g. vmap it over a batch of keys:

>>> print(jax.vmap(fun)(jax.random.split(key, 3)))
[{'cat': Array([1.5, 1.5, 1.5], dtype=float32)}, Array([3., 3., 3.], dtype=float32)]
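
Since the returned function uses only JAX operations, it should also compose with other transformations such as jax.jit. As an illustration, here is a hand-written function that behaves like the fun compiled above (fun_by_hand is a hypothetical stand-in, not the code pangolin generates), jitted and then vmapped:

```python
import jax
import jax.numpy as jnp


def fun_by_hand(key):
    # Same model as above: x is the constant 1.5 and y = x + x.
    # The key is unused because the model contains no random operations.
    x = jnp.float32(1.5)
    y = x + x
    return [{"cat": x}, y]


fast = jax.jit(fun_by_hand)
key = jax.random.PRNGKey(0)
print(fast(key))
# vmap over split keys adds a leading axis of length 3 to every leaf,
# broadcasting outputs that do not depend on the key.
batched = jax.vmap(fast)(jax.random.split(key, 3))
print(batched[1])  # [3. 3. 3.]
```
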

pangolin.jax_backend.ancestor_log_prob(*vars, **kwvars)

Given one or more pytrees of RVs, create a plain-old JAX function that computes their joint log probability.

Parameters:
  • vars (PyTree[RV]) – pytrees of RV, given as positional arguments

  • kwvars (PyTree[RV]) – more pytrees of RV, given as keyword arguments

Returns:

out – log-prob function that expects jax.ndarray arguments matching vars and kwvars and returns a scalar.

Return type:

Callable

Examples

>>> loc = RV(ir.Constant(0.0))
>>> scale = RV(ir.Constant(1.0))
>>> x = RV(ir.Normal(), loc, scale)
>>> fun = ancestor_log_prob(x)

You now have a plain JAX function that’s completely independent of pangolin. You can evaluate it.

>>> fun(0.0)
Array(-0.9189385, dtype=...)

Or you can vmap it.

>>> jax.vmap(fun)(jnp.array([0.0, 0.5]))
Array([-0.9189385, -1.0439385], dtype=float32)
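
The printed values follow from the normal log-density log N(v | μ, σ) = -log σ - ½ log(2π) - (v - μ)²/(2σ²). With μ = 0 and σ = 1 this gives -½ log(2π) ≈ -0.9189385 at v = 0, and a further 0.5²/2 = 0.125 is subtracted at v = 0.5. The short check below (plain JAX, independent of pangolin) compares a hand-written version, normal_logpdf (an illustrative name), with jax.scipy.stats.norm.logpdf:

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def normal_logpdf(v, loc, scale):
    # Log of the Normal(loc, scale) density evaluated at v.
    z = (v - loc) / scale
    return -jnp.log(scale) - 0.5 * math.log(2 * math.pi) - 0.5 * z**2


vs = jnp.array([0.0, 0.5])
by_hand = jax.vmap(lambda v: normal_logpdf(v, 0.0, 1.0))(vs)
reference = norm.logpdf(vs, loc=0.0, scale=1.0)
print(by_hand)                           # approximately [-0.9189385 -1.0439385]
print(jnp.allclose(by_hand, reference))  # True
```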

Here’s a more complex example, in which y vectorizes the normal over 3 entries that share loc and scale:

>>> op = ir.VMap(ir.Normal(), [None,None], 3)
>>> y = RV(op,loc,scale)
>>> fun = ancestor_log_prob({'x':x, 'y':y})
>>> fun({'x':0.0, 'y':[0.0, 0.5, 0.1]})
Array(-3.8057542, dtype=...)
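
Because x and the three entries of y are independent given the constants loc and scale, the joint log probability is a sum of four standard-normal log-densities: 4 × (-0.9189385) - (0² + 0.5² + 0.1²)/2 ≈ -3.8057542. A hand-written plain-JAX equivalent of fun for this model (the name log_prob_by_hand is hypothetical) makes that sum explicit:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_prob_by_hand(values):
    # x ~ Normal(0, 1): one scalar term.
    lp_x = norm.logpdf(values["x"], 0.0, 1.0)
    # y maps Normal(0, 1) over 3 entries sharing loc and scale
    # (in_axes [None, None]), so its log-prob sums over those entries.
    lp_y = jnp.sum(norm.logpdf(jnp.asarray(values["y"]), 0.0, 1.0))
    return lp_x + lp_y


print(log_prob_by_hand({"x": 0.0, "y": [0.0, 0.5, 0.1]}))  # approximately -3.8057542
```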

You can also create a function that uses positional and/or keyword arguments:

>>> fun = ancestor_log_prob(x, cat=y)
>>> fun(0.0, cat=[0.0, 0.5, 0.1])
Array(-3.8057542, dtype=...)
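
When one variable's distribution depends on another, the joint log probability is the sum of each variable's log-density conditioned on its parents' values, taken in the same ancestor order used for sampling. For a hypothetical two-variable model x ~ Normal(0, 1), y ~ Normal(x, 1) (not one of the examples above), a hand-written plain-JAX version using the same positional/keyword calling style might look like this; chain_log_prob is an illustrative name, not a pangolin function:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def chain_log_prob(x, *, y):
    # log p(x, y) = log p(x) + log p(y | x)
    lp_x = norm.logpdf(x, loc=0.0, scale=1.0)
    # The child's density is evaluated at the supplied parent value x.
    lp_y = norm.logpdf(y, loc=x, scale=1.0)
    return lp_x + lp_y


print(chain_log_prob(0.0, y=0.0))  # approximately -1.837877
print(chain_log_prob(1.0, y=1.0))  # approximately -2.337877
```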