blackjax Module
This module defines a convenient interface for calling Blackjax to do inference. You could of course call pangolin.jax_backend.ancestor_log_prob to get a plain JAX function and then call Blackjax yourself, but this module abstracts away those details.
- pangolin.blackjax.sample(vars, given_vars=None, given_vals=None, reduce_fn=None, **options)
Default version of Calculate.sample that draws 1000 samples.
- pangolin.blackjax.E(vars, given_vars=None, given_vals=None, **options)
Default version of Calculate.E that uses 1000 samples.
- pangolin.blackjax.var(vars, given_vars=None, given_vals=None, **options)
Default version of Calculate.var that uses 1000 samples.
- pangolin.blackjax.std(vars, given_vars=None, given_vals=None, **options)
Default version of Calculate.std that uses 1000 samples.
- class pangolin.blackjax.Calculate(default=None, frozen=None)
A Calculate object just remembers a set of options and then offers inference methods.
- Parameters:
default (Optional[dict]) – a set of options for the inference engine that can be overridden later
frozen (Optional[dict]) – a set of options for the inference engine that cannot be overridden
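The split between default and frozen options can be pictured with a small sketch. This is not pangolin's implementation: the OptionStore class, its resolve method, the 'nchains' option name, and the choice to raise ValueError on a frozen override are all hypothetical, introduced only to illustrate the idea. Only the 'niter' option name appears in the examples on this page.

```python
class OptionStore:
    """Hypothetical stand-in for the option handling described above."""

    def __init__(self, default=None, frozen=None):
        self.default = dict(default or {})
        self.frozen = dict(frozen or {})

    def resolve(self, **options):
        # Assumption: overriding a frozen option is an error.
        clashes = sorted(set(options) & set(self.frozen))
        if clashes:
            raise ValueError(f"cannot override frozen options: {clashes}")
        # Later dicts win: call-time options override defaults,
        # and frozen options are applied last.
        return {**self.default, **options, **self.frozen}


# 'nchains' is a made-up option name used only for illustration.
store = OptionStore(default={"niter": 1000}, frozen={"nchains": 1})
merged = store.resolve(niter=529)  # call-time value replaces the default
```

Under these assumptions, a call without extra options falls back to the stored defaults, while an attempt to change a frozen option fails loudly rather than being silently ignored.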
- sample(vars, given_vars=None, given_vals=None, reduce_fn=None, **options)
Draw samples!
- Parameters:
vars (PyTree[RV]) – An RV, list/tuple of RV, or pytree of RV to sample.
given_vars (PyTree[RV]) – An RV, list/tuple of RV, or pytree of RV to condition on. None indicates no conditioning variables.
given_vals (PyTree[ArrayLike]) – An ArrayLike, list/tuple of ArrayLike, or pytree of ArrayLike representing observed values. Must match the structure and shape of given_vars.
reduce_fn (Optional[Callable]) – Function to apply to each leaf node in samples before returning. This is used to create E, var, etc. (If None, does nothing.)
options – extra options to pass to the sampler
- Returns:
Pytree of JAX arrays matching the structure and shape of vars but with one extra dimension at the start, containing the samples.
Examples
>>> zero = ir.RV(ir.Constant(0))
>>> one = ir.RV(ir.Constant(1))
>>> x = ir.RV(ir.Normal(), zero, one)
>>> y = ir.RV(ir.Normal(), x, one)
>>> calc = Calculate({'niter': 529})
>>> x_samps = calc.sample(x,y,2)
>>> x_samps.shape
(529,)
>>> np.mean(x_samps) # something close to 1.0
Array(...)
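The "something close to 1.0" comment in the example follows from the standard conjugate-normal update, which can be checked by hand. The sketch below does not use pangolin; normal_posterior is a hypothetical helper written for this check. It assumes x ~ Normal(0, 1) and y | x ~ Normal(x, 1) with y observed as 2. Because the second argument is 1, it does not matter here whether it is read as a standard deviation or a variance.

```python
def normal_posterior(prior_mean, prior_var, obs, obs_var):
    """Posterior of x given one observation y = obs, where
    x ~ Normal(prior_mean, prior_var) and y | x ~ Normal(x, obs_var)."""
    prior_prec = 1.0 / prior_var  # precision = 1 / variance
    obs_prec = 1.0 / obs_var
    # Precisions add; the posterior mean is the precision-weighted average.
    post_var = 1.0 / (prior_prec + obs_prec)
    post_mean = post_var * (prior_prec * prior_mean + obs_prec * obs)
    return post_mean, post_var


# Prior x ~ Normal(0, 1), observation y = 2 with unit noise variance.
post_mean, post_var = normal_posterior(0.0, 1.0, 2.0, 1.0)
```

The exact posterior is Normal with mean 1.0 and variance 0.5, so the average of MCMC draws of x should land near 1.0, with some Monte Carlo error at 529 iterations.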
- E(vars, given_vars=None, given_vals=None, **options)
Compute (conditional) expected values. This is just a thin wrapper that calls sample and then reduces by taking the mean.
- Parameters:
vars (PyTree[RV]) – An RV, list/tuple of RV, or pytree of RV to sample.
given_vars (PyTree[RV]) – An RV, list/tuple of RV, or pytree of RV to condition on. None indicates no conditioning variables.
given_vals (PyTree[ArrayLike]) – An ArrayLike, list/tuple of ArrayLike, or pytree of ArrayLike representing observed values. Must match the structure and shape of given_vars.
reduce_fn – Function to apply to each leaf node in samples before returning. This is used to create E, var, etc. (If None, does nothing.)
options – extra options to pass to the sampler
- Returns:
Pytree of JAX arrays matching the structure and shape of vars, containing the expectations.
Examples
>>> zero = ir.RV(ir.Constant(0))
>>> one = ir.RV(ir.Constant(1))
>>> x = ir.RV(ir.Normal(), zero, one)
>>> y = ir.RV(ir.Normal(), x, one)
>>> calc = Calculate({'niter': 529})
>>> calc.E(x,y,2) # something close to 1.0
Array(...)
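Since E is described as calling sample and then reducing by the mean, the "reduce each leaf" idea can be sketched in plain Python. This is an illustration only, not pangolin code: reduce_leaves is a hypothetical helper, leaves are modeled as plain lists of scalar draws rather than JAX arrays, and the use of population variance and standard deviation is an assumption about how var and std might be defined.

```python
from statistics import fmean, pstdev, pvariance


def reduce_leaves(samples, reduce_fn):
    """Apply reduce_fn to every leaf of a nested dict/tuple structure.

    In this sketch a leaf is a plain list of scalar draws; dicts and
    tuples are containers that are traversed recursively.
    """
    if isinstance(samples, dict):
        return {name: reduce_leaves(sub, reduce_fn) for name, sub in samples.items()}
    if isinstance(samples, tuple):
        return tuple(reduce_leaves(sub, reduce_fn) for sub in samples)
    return reduce_fn(samples)  # leaf: collapse the draws to one number


# Toy draws standing in for sampler output (values are made up).
draws = {"x": [0.5, 1.0, 1.5], "yz": ([2.0, 4.0], [1.0, 1.0, 1.0])}

means = reduce_leaves(draws, fmean)          # analogous to E
variances = reduce_leaves(draws, pvariance)  # analogous to var
stds = reduce_leaves(draws, pstdev)          # analogous to std
```

The result keeps the structure of the input while the sample dimension disappears from each leaf, which mirrors the documented difference between the return values of sample and E.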
- sample_arviz(vars, given_vars=None, given_vals=None, **options)
This is an experimental function to draw samples in ArviZ format.
Note: ArviZ is not installed with pangolin by default; you must install it manually.
- Parameters:
vars (dict[str, RV]) – Dictionary mapping names to individual random variables.
given_vars (PyTree[RV]) – An RV, list/tuple of RV, or pytree of RV to condition on. None indicates no conditioning variables.
given_vals (PyTree[ArrayLike]) – An ArrayLike, list/tuple of ArrayLike, or pytree of ArrayLike representing observed values. Must match the structure and shape of given_vars.
reduce_fn – Function to apply to each leaf node in samples before returning. This is used to create E, var, etc. (If None, does nothing.)
options – extra options to pass to the sampler