jax_backend

This module compiles pangolin models into plain-old JAX functions.

These functions accept an optional biject argument. If True, constrained continuous distributions are automatically transformed to an unconstrained space. If False, no transformation is applied. A dictionary can also be provided to control precisely how distributions are transformed. See the bijectors submodule for more information.
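
As a rough illustration of what transforming to an unconstrained space involves, the sketch below uses plain Python, independent of pangolin and JAX. It assumes an Exponential(1) variable and a log transform. Both are choices made for this example only, and the bijector pangolin actually applies to a given distribution may differ. The names exp_log_density and unconstrained_log_density are hypothetical.

```python
import math


def exp_log_density(x):
    """Log density of Exponential(rate=1), defined only for x > 0."""
    if x <= 0:
        return -math.inf
    return -x


def unconstrained_log_density(y):
    """Log density of y = log(x), which ranges over the whole real line.

    Change of variables: log p_Y(y) = log p_X(exp(y)) + log|d exp(y)/dy|,
    and the log-Jacobian term is simply y.
    """
    x = math.exp(y)
    return exp_log_density(x) + y


# Any real y is a valid input, even though x itself must be positive.
values = [unconstrained_log_density(y) for y in (-2.0, 0.0, 1.5)]
```

The extra y term is the log-Jacobian correction. Without it, the transformed function would no longer be a properly normalized log density.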

pangolin.jax_backend.ancestor_sample(vars, key=None, size=None, biject=False)

Draw exact samples.

Parameters:
  • vars (PyTree[RV]) – a Pytree of RV to sample

  • key (Optional[JaxArray]) – a JAX PRNGKey or None (default)

  • size (Optional[int]) – number of samples to draw (the default, None, draws a single sample)

  • biject (bool | dict) – A boolean determining if constrained continuous distributions should be transformed to an unconstrained space. Or a dict giving a non-default set of bijectors. (Default: False)

Returns:

out – Pytree matching the structure of vars, but with jax.ndarray arrays in place of each RV. If size is None, each array has the same shape as the corresponding RV. Otherwise, each array has an extra leading dimension of length size.
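
The shape rule for the returned arrays can be written down as a small helper. This is only a sketch of the rule as documented here. The function sample_shape is hypothetical and not part of pangolin.

```python
def sample_shape(rv_shape, size=None):
    """Shape of the array returned for an RV whose shape is `rv_shape`."""
    rv_shape = tuple(rv_shape)
    # size=None means a single draw with the same shape as the RV.
    if size is None:
        return rv_shape
    # Otherwise the draws are stacked along a new leading axis.
    return (size,) + rv_shape
```

For a scalar RV, sample_shape((), 3) gives (3,), which agrees with the size=3 example in this section.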

Examples

Sample a constant RV.

>>> x = RV(ir.Constant(1.5))
>>> ancestor_sample(x)
Array(1.5, dtype=...)

Sample a PyTree with the RV inside it.

>>> ancestor_sample({'sup': [[x]]})
{'sup': [[Array(1.5, dtype=...)]]}

Draw several samples.

>>> ancestor_sample(x, size=3)
Array([1.5, 1.5, 1.5], dtype=...)

Draw several samples from a PyTree with an RV inside it.

>>> ancestor_sample({'sup': x}, size=3)
{'sup': Array([1.5, 1.5, 1.5], dtype=...)}

Sample from several random variables at once.

>>> y = RV(ir.Add(), x, x)
>>> z = RV(ir.Mul(), x, y)
>>> print(ancestor_sample({'cat': x, 'dog': [y, z]}))
{'cat': Array(1.5, dtype=...), 'dog': [Array(3., dtype=...), Array(4.5, dtype=...)]}
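
The idea behind ancestor sampling is that every variable is drawn only after its parents have been drawn. The sketch below shows this in plain Python with the standard random module. It is not how pangolin is implemented. The graph encoding, the function ancestor_sample_sketch, and the extra random node w are assumptions introduced for illustration, while x, y and z mirror the example above.

```python
import random

# Each node maps to (sampler, parent names). A sampler receives the random
# generator followed by the already-sampled values of its parents.
graph = {
    "x": (lambda rng: 1.5, []),                          # constant
    "y": (lambda rng, a, b: a + b, ["x", "x"]),          # deterministic
    "z": (lambda rng, a, b: a * b, ["x", "y"]),          # deterministic
    "w": (lambda rng, loc: rng.gauss(loc, 1.0), ["y"]),  # random given y
}


def ancestor_sample_sketch(graph, targets, seed=0):
    """Draw one joint sample, visiting every ancestor before its children."""
    rng = random.Random(seed)
    values = {}

    def visit(name):
        if name not in values:
            sampler, parents = graph[name]
            # Recursing into the parents first yields a topological order.
            parent_values = [visit(p) for p in parents]
            values[name] = sampler(rng, *parent_values)
        return values[name]

    return {name: visit(name) for name in targets}


draw = ancestor_sample_sketch(graph, ["x", "y", "z", "w"])
```

Each node is sampled once and cached, so x has the same value everywhere it appears as a parent. This is what makes the result a joint sample rather than a set of independent draws.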

pangolin.jax_backend.ancestor_sampler(vars, biject=False)

Compiles a pytree of RVs into a plain-old JAX function that takes a PRNGKey and returns a pytree with the same structure containing a joint sample from the distribution of those RVs.

Parameters:
  • vars (PyTree[RV]) – Pytree of RV to sample

  • biject (bool | dict) – A boolean determining if constrained continuous distributions should be transformed to an unconstrained space. Or a dict giving a non-default set of bijectors. (Default: False)

Returns:

out – function mapping a JAX PRNGKey to a sample in the form of a pytree of jax.ndarray matching the structure and shape of vars

Return type:

Callable[[JaxArray], Any]

Examples

>>> x = RV(ir.Constant(1.5))
>>> y = RV(ir.Add(), x, x)
>>> fun = ancestor_sampler([{'cat': x}, y])

You now have a plain-old JAX function that’s completely independent of pangolin.

>>> key = jax.random.PRNGKey(0)
>>> fun(key)
[{'cat': Array(1.5, dtype=float32)}, Array(3., dtype=...)]

You can do normal JAX stuff with it, e.g. vmap it.

>>> print(jax.vmap(fun)(jax.random.split(key, 3)))
[{'cat': Array([1.5, 1.5, 1.5], dtype=float32)}, Array([3., 3., 3.], dtype=...)]

pangolin.jax_backend.ancestor_log_prob(*vars, biject=False, **kwvars)

Given pytrees of RVs, create a plain-old JAX function that computes log probabilities.

Parameters:
  • vars (PyTree[RV]) – pytrees of RV

  • biject (bool | dict) – A boolean determining if constrained continuous distributions should be transformed to an unconstrained space. Or a dict giving a non-default set of bijectors. Only available as a keyword argument. (Default: False)

  • kwvars (PyTree[RV]) – more pytrees of RV, given as keyword arguments

Returns:

out – log-prob function that expects jax.ndarray arguments matching vars and kwvars and returns a scalar.

Return type:

Callable

Examples

>>> loc = ir.RV(ir.Constant(0.0))
>>> scale = ir.RV(ir.Constant(1.0))
>>> x = RV(ir.Normal(),loc,scale)
>>> fun = ancestor_log_prob(x)

You now have a plain JAX function that’s completely independent of pangolin. You can evaluate it.

>>> fun(0.0)
Array(-0.9189385, dtype=...)

Or you can vmap it.

>>> jax.vmap(fun)(jnp.array([0.0, 0.5]))
Array([-0.9189385, -1.0439385], dtype=...)
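
The two values above can be checked by hand against the standard normal log density. The following sketch does so in plain Python, without pangolin or JAX. The helper normal_log_pdf is a hypothetical name for this example.

```python
import math


def normal_log_pdf(x, loc=0.0, scale=1.0):
    """Log density of Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


# Single evaluation, then an elementwise map playing the role of vmap.
single = normal_log_pdf(0.0)
batch = [normal_log_pdf(v) for v in (0.0, 0.5)]
```

Up to float32 rounding, single and batch agree with the outputs shown above.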

Here’s a more complex example:

>>> op = ir.VMap(ir.Normal(), [None,None], 3)
>>> y = RV(op,loc,scale)
>>> fun = ancestor_log_prob({'x':x, 'y':y})
>>> fun({'x':0.0, 'y':[0.0, 0.5, 0.1]})
Array(-3.8057542, dtype=...)
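
The joint value above is consistent with summing independent Normal(0, 1) log densities over x and the three entries of y. The sketch below reproduces it in plain Python under that assumption. It says nothing about how pangolin evaluates the VMap internally, and joint_log_prob is a hypothetical helper.

```python
import math


def normal_log_pdf(x, loc=0.0, scale=1.0):
    """Log density of Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def joint_log_prob(values):
    """Sum the log densities of x and of every element of y."""
    total = normal_log_pdf(values["x"])
    total += sum(normal_log_pdf(v) for v in values["y"])
    return total


joint = joint_log_prob({"x": 0.0, "y": [0.0, 0.5, 0.1]})
```

In double precision this gives about -3.8057541, which matches the float32 result above to roughly seven significant digits.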

You can also create a function that uses positional and/or keyword arguments:

>>> fun = ancestor_log_prob(x, cat=y)
>>> fun(0.0, cat=[0.0, 0.5, 0.1])
Array(-3.8057542, dtype=...)

Subpackages