blackjax Module

This module provides a convenient interface for running inference with Blackjax. You could instead call pangolin.jax_backend.ancestor_log_prob to obtain a plain JAX function and then call Blackjax yourself, but this module handles those details for you.

pangolin.blackjax.sample(vars, given_vars=None, given_vals=None, reduce_fn=None, **options)

Default version of Calculate.sample that draws 1000 samples.

Parameters:
  • vars (PyTree[RV])

  • given_vars (PyTree[RV])

  • given_vals (PyTree[ArrayLike])

  • reduce_fn (Optional[Callable])

pangolin.blackjax.E(vars, given_vars=None, given_vals=None, **options)

Default version of Calculate.E that uses 1000 samples.

Parameters:
  • vars (PyTree[RV])

  • given_vars (PyTree[RV])

  • given_vals (PyTree[ArrayLike])

pangolin.blackjax.var(vars, given_vars=None, given_vals=None, **options)

Default version of Calculate.var that uses 1000 samples.

Parameters:
  • vars (PyTree[RV])

  • given_vars (PyTree[RV])

  • given_vals (PyTree[ArrayLike])

pangolin.blackjax.std(vars, given_vars=None, given_vals=None, **options)

Default version of Calculate.std that uses 1000 samples.

Parameters:
  • vars (PyTree[RV])

  • given_vars (PyTree[RV])

  • given_vals (PyTree[ArrayLike])

class pangolin.blackjax.Calculate(default=None, frozen=None)

A Calculate object just remembers a set of options and then offers inference methods.

Parameters:
  • default (Optional[dict]) – represents a set of options for the inference engine that can be overridden later

  • frozen (Optional[dict]) – represents a set of options for the inference engine that cannot be overridden
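How default and frozen options combine with per-call options is not spelled out here. One plausible reading is sketched below in plain Python. The helper `resolve_options` and the option key `example_flag` are hypothetical and not part of pangolin; only `niter` appears in this page's examples. Whether pangolin raises an error or silently ignores an attempt to override a frozen option is an assumption of this sketch.

```python
def resolve_options(default=None, frozen=None, **overrides):
    """Hypothetical helper (not pangolin's code): merge option dicts.

    Assumed precedence: default < per-call overrides < frozen.
    Assumed behavior: trying to override a frozen key raises ValueError.
    """
    default = dict(default or {})
    frozen = dict(frozen or {})
    clash = set(overrides) & set(frozen)
    if clash:
        raise ValueError(f"cannot override frozen options: {sorted(clash)}")
    return {**default, **overrides, **frozen}


# A per-call `niter` replaces the default; the frozen key is carried through.
merged = resolve_options(
    default={"niter": 1000},
    frozen={"example_flag": True},  # hypothetical option name
    niter=529,
)
print(merged)
```

Under this reading, `default` is a starting point that each call may adjust, while `frozen` pins values for every call made through the same Calculate object.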

sample(vars, given_vars=None, given_vals=None, reduce_fn=None, **options)

Draw samples!

Parameters:
  • vars (PyTree[RV]) – A RV or list/tuple of RV or pytree of RV to sample.

  • given_vars (PyTree[RV]) – A RV or list/tuple of RV or pytree of RV to condition on. None indicates no conditioning variables.

  • given_vals (PyTree[ArrayLike]) – An ArrayLike or list/tuple of ArrayLike or pytree of ArrayLike representing observed values. Must match the structure and shape of given_vars.

  • reduce_fn (Optional[Callable]) – Function to apply to each leaf node in samples before returning. This is used to create E, var, etc. (If None, does nothing.)

  • options – extra options to pass to sampler

Returns:

Pytree of JAX arrays matching structure and shape of vars but with one extra dimension at the start, containing the samples.
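To make the role of reduce_fn concrete, the sketch below applies a reduction to every leaf of a nested structure of samples, which is how E can be obtained from sample by reducing with a mean. It uses plain Python lists as stand-ins for JAX arrays, and the helper `map_leaves` is hypothetical; pangolin presumably operates on real pytrees of JAX arrays instead.

```python
from statistics import fmean


def map_leaves(fn, tree):
    """Apply fn to every leaf. Dicts and tuples are containers; anything else is a leaf."""
    if isinstance(tree, dict):
        return {k: map_leaves(fn, v) for k, v in tree.items()}
    if isinstance(tree, tuple):
        return tuple(map_leaves(fn, v) for v in tree)
    return fn(tree)


# Stand-in for the output of sample: each leaf holds 4 draws of a scalar RV,
# with the sample index as the leading (here, only) dimension.
samples = {
    "x": [0.5, 1.5, 1.0, 1.0],
    "pair": ([2.0, 4.0, 2.0, 4.0], [0.0, 0.0, 1.0, 1.0]),
}

# Reducing each leaf with a mean removes the sample dimension, so the result
# has the same structure as the variables that were requested.
means = map_leaves(fmean, samples)
print(means)
```

Swapping `fmean` for a variance or standard-deviation function gives the analogous reductions.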

Examples

>>> import numpy as np
>>> from pangolin import ir
>>> from pangolin.blackjax import Calculate
>>> zero    = ir.RV(ir.Constant(0))
>>> one     = ir.RV(ir.Constant(1))
>>> x       = ir.RV(ir.Normal(), zero, one)
>>> y       = ir.RV(ir.Normal(), x, one)
>>> calc    = Calculate({'niter': 529})
>>> x_samps = calc.sample(x,y,2)
>>> x_samps.shape
(529,)
>>> np.mean(x_samps) # something close to 1.0
Array(...)

E(vars, given_vars=None, given_vals=None, **options)

Compute (conditional) expected values. This is just a thin wrapper that calls sample and then reduces by taking the mean.

Parameters:
  • vars (PyTree[RV]) – A RV or list/tuple of RV or pytree of RV to sample.

  • given_vars (PyTree[RV]) – A RV or list/tuple of RV or pytree of RV to condition on. None indicates no conditioning variables.

  • given_vals (PyTree[ArrayLike]) – An ArrayLike or list/tuple of ArrayLike or pytree of ArrayLike representing observed values. Must match the structure and shape of given_vars.

  • reduce_fn – Function to apply to each leaf node in samples before returning. This is used to create E, var, etc. (If None, does nothing.)

  • options – extra options to pass to sampler

Returns:

Pytree of JAX arrays matching structure and shape of vars, containing the expectations.

Examples

>>> from pangolin import ir
>>> from pangolin.blackjax import Calculate
>>> zero    = ir.RV(ir.Constant(0))
>>> one     = ir.RV(ir.Constant(1))
>>> x       = ir.RV(ir.Normal(), zero, one)
>>> y       = ir.RV(ir.Normal(), x, one)
>>> calc    = Calculate({'niter': 529})
>>> calc.E(x,y,2) # something close to 1.0
Array(...)

var(vars, given_vars=None, given_vals=None, **options)

Parameters:
  • vars (PyTree[RV])

  • given_vars (PyTree[RV])

  • given_vals (PyTree[ArrayLike])

std(vars, given_vars=None, given_vals=None, **options)

Parameters:
  • vars (PyTree[RV])

  • given_vars (PyTree[RV])

  • given_vals (PyTree[ArrayLike])
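For the two-variable model used in the examples above, the posterior is available in closed form, which gives reference values against which E, var, and std can be checked. This assumes that var and std report the posterior variance and standard deviation, as their names suggest. Because both scale parameters equal 1, it does not matter here whether the second argument of Normal is a standard deviation or a variance.

```latex
\begin{aligned}
x &\sim \mathcal{N}(0, 1), \qquad y \mid x \sim \mathcal{N}(x, 1), \\
p(x \mid y) &\propto \exp\!\left(-\tfrac{1}{2}x^{2}\right)
               \exp\!\left(-\tfrac{1}{2}(y - x)^{2}\right)
             \propto \exp\!\left(-\left(x - \tfrac{y}{2}\right)^{2}\right), \\
x \mid y &\sim \mathcal{N}\!\left(\tfrac{y}{2},\ \tfrac{1}{2}\right).
\end{aligned}
```

With the observed value y = 2, the posterior mean is 1, the posterior variance is 1/2, and the posterior standard deviation is 1/√2 ≈ 0.707. This is why the examples describe their results as something close to 1.0.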

sample_arviz(vars, given_vars=None, given_vals=None, **options)

This is an experimental function to draw samples in ArviZ format.

Note: ArviZ is not installed with pangolin by default: You must install it manually.

Parameters:
  • vars (dict[str, RV]) – dictionary mapping names to individual random variables

  • given_vars (PyTree[RV]) – A RV or list/tuple of RV or pytree of RV to condition on. None indicates no conditioning variables.

  • given_vals (PyTree[ArrayLike]) – An ArrayLike or list/tuple of ArrayLike or pytree of ArrayLike representing observed values. Must match the structure and shape of given_vars.

  • reduce_fn – Function to apply to each leaf node in samples before returning. This is used to create E, var, etc. (If None, does nothing.)

  • options – extra options to pass to sampler
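Since ArviZ has to be installed separately, it can help to check for it before calling sample_arviz. The sketch below uses only the standard library; the helper name `arviz_available` is hypothetical and not part of pangolin.

```python
import importlib.util


def arviz_available():
    """Return True if the optional `arviz` package can be found for import."""
    return importlib.util.find_spec("arviz") is not None


if not arviz_available():
    print("ArviZ not found; install it first, e.g. with `pip install arviz`.")
```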

pangolin.blackjax.inf_until_match(inf, vars, given, vals, testfun, niter_start=1000, niter_max=100000)
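No description is given for inf_until_match. Judging only from its signature, it may rerun an inference function with a growing number of iterations until a test function accepts the result. The sketch below illustrates that general pattern with a hypothetical `retry_until_match`. The doubling schedule, the error raised on failure, and the way `inf` is called are all assumptions, not pangolin's documented behavior.

```python
def retry_until_match(inf, testfun, niter_start=1000, niter_max=100000):
    """Hypothetical sketch: call inf(niter=...) with growing niter until testfun accepts."""
    niter = niter_start
    while niter <= niter_max:
        result = inf(niter=niter)
        if testfun(result):
            return result, niter
        niter *= 2  # assumed schedule: double the iteration count after each failure
    raise AssertionError(f"no match with niter up to {niter_max}")


# Toy stand-in for an inference routine: its "estimate" is simply niter itself,
# and the test accepts any estimate of at least 3000.
result, used = retry_until_match(lambda niter: niter, lambda r: r >= 3000)
print(result, used)
```

A pattern like this suits stochastic tests, where a cheap run is tried first and more iterations are spent only when the estimate is not yet accurate enough.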