jax_backend.bijectors

A bijector dictionary maps primitive random Op classes to JaxBijector instances. Bijectors are commonly used to transform constrained random Ops (like Dirichlet or Uniform) to an unconstrained space, which makes gradient-based inference easier. Such a dictionary should cover all base random Ops, excluding Composite, VMap, and Scan.

class pangolin.jax_backend.bijectors.JaxBijector(forward, inverse, log_det_jac)

Represents a diffeomorphism T together with its inverse and the log determinant of its Jacobian. If P(X) is a density and Y = T(X), then the density of Y is P(Y=y) = P(X=T⁻¹(y)) × |det ∇T⁻¹(y)|.
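
As a quick numerical check of this formula, the sketch below (plain JAX, independent of pangolin) takes X ~ Uniform(0, 1) and T = logit, so that T⁻¹ is the sigmoid and P(X=x) = 1 on (0, 1). The formula then predicts that Y = T(X) has the standard logistic density. The derivative of T⁻¹ is computed with jax.grad rather than by hand.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import logistic

# T = logit maps (0, 1) to the reals, so T⁻¹ = sigmoid maps the reals back to (0, 1).
inverse = jax.nn.sigmoid

def density_y(y):
    # P(X = T⁻¹(y)) is 1 everywhere on (0, 1) for a Uniform(0, 1) variable.
    p_x = 1.0
    # In one dimension, |det ∇T⁻¹(y)| is just |dT⁻¹/dy|.
    jac = jnp.abs(jax.grad(inverse)(y))
    return p_x * jac

ys = jnp.linspace(-4.0, 4.0, 9)
p_y = jax.vmap(density_y)(ys)

# The change-of-variables formula should reproduce the standard logistic density.
expected = logistic.pdf(ys)
print(jnp.allclose(p_y, expected, atol=1e-6))
```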

Parameters:
  • forward – JAX function implementing the forward transformation, given x

  • inverse – JAX function implementing the inverse transformation, given y

  • log_det_jac – JAX function computing log |det ∇T(x)|, the log absolute determinant of the Jacobian of the forward transformation, given both x and y (it may use either, as convenient)
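
For example, the three functions for T(x) = exp(x) could look like the following sketch. Since dT/dx = exp(x) = y, the log determinant can be computed from either argument: log |exp(x)| = x = log y. The sketch checks this against jax.grad; the call JaxBijector(forward, inverse, log_det_jac) appears only in a comment, and whether the library's own exp() is built exactly this way is an assumption.

```python
import jax
import jax.numpy as jnp

def forward(x):
    # T(x) = exp(x)
    return jnp.exp(x)

def inverse(y):
    # T⁻¹(y) = log(y)
    return jnp.log(y)

def log_det_jac(x, y):
    # log |dT/dx| = log(exp(x)) = x; y is unused, though log(y) would give the same value
    return x

# Hypothetical usage with the documented constructor:
# bij = JaxBijector(forward, inverse, log_det_jac)

x = jnp.array(0.7)
y = forward(x)
autodiff = jnp.log(jnp.abs(jax.grad(forward)(x)))
print(jnp.allclose(inverse(y), x), jnp.allclose(log_det_jac(x, y), autodiff))
```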

forward(x)

Computes T(x).

inverse(y)

Computes T⁻¹(y).

log_det_jac(x, y)

Computes log |det ∇T(x)| = -log |det ∇T⁻¹(y)|. May use either x or y as convenient.
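
The identity log |det ∇T(x)| = -log |det ∇T⁻¹(y)| can be checked numerically with jax.jacfwd and jnp.linalg.slogdet. This sketch uses a small, made-up two-dimensional map T(x₁, x₂) = (exp(x₁), x₂·exp(x₁)), whose inverse is T⁻¹(y₁, y₂) = (log y₁, y₂/y₁); by hand, log |det ∇T(x)| = 2x₁.

```python
import jax
import jax.numpy as jnp

def T(x):
    # Illustrative map from R^2 to (0, inf) x R
    return jnp.stack([jnp.exp(x[0]), x[1] * jnp.exp(x[0])])

def T_inv(y):
    return jnp.stack([jnp.log(y[0]), y[1] / y[0]])

x = jnp.array([0.4, -1.3])
y = T(x)

# slogdet returns (sign, log|det|); only the magnitude matters here.
_, logdet_fwd = jnp.linalg.slogdet(jax.jacfwd(T)(x))
_, logdet_inv = jnp.linalg.slogdet(jax.jacfwd(T_inv)(y))

print(logdet_fwd, -logdet_inv, 2 * x[0])
```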

forward_and_log_det_jac(x)

Computes T(x) and log |det ∇T(x)|.

inverse_and_log_det_jac(y)

Computes x = T⁻¹(y) and log |det ∇T(x)|, the log determinant of the Jacobian of the forward transformation. (Negate the second return value to get log |det ∇T⁻¹(y)|, the log determinant of the Jacobian of the inverse transformation.)
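
In terms of the three constructor arguments, the two methods above behave like the following hypothetical helpers (a sketch of the semantics for T(x) = exp(x), not the library's source). Both return the log determinant of the forward transformation at the matching pair (x, y); negating it gives the log determinant of T⁻¹ at y.

```python
import jax.numpy as jnp

# Hypothetical stand-ins for the constructor arguments, here for T(x) = exp(x).
forward = jnp.exp
inverse = jnp.log

def log_det_jac(x, y):
    return x  # log |d exp(x)/dx| = x

def forward_and_log_det_jac(x):
    y = forward(x)
    return y, log_det_jac(x, y)

def inverse_and_log_det_jac(y):
    x = inverse(y)
    # Still log |det ∇T(x)|; flip the sign to get log |det ∇T⁻¹(y)|.
    return x, log_det_jac(x, y)

y, ld_fwd = forward_and_log_det_jac(jnp.array(1.0))
x, ld = inverse_and_log_det_jac(y)
log_det_of_inverse = -ld  # equals log |d log(y)/dy| = -log(y)
```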

property reverse

Get a JaxBijector for T⁻¹.
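
Conceptually, reversing a bijector swaps forward and inverse and negates the log determinant, since log |det ∇T⁻¹(y)| = -log |det ∇T(x)|. The minimal sketch below uses a hypothetical NamedTuple stand-in (SimpleBijector) rather than the real JaxBijector class, so it illustrates the idea rather than the library's implementation.

```python
from typing import Callable, NamedTuple

import jax.numpy as jnp

class SimpleBijector(NamedTuple):
    # Hypothetical stand-in mirroring JaxBijector's constructor arguments
    forward: Callable
    inverse: Callable
    log_det_jac: Callable  # called as log_det_jac(x, y) with y = forward(x)

def reverse(b):
    # For the reversed bijector the roles swap: its input is y and its output is x.
    return SimpleBijector(
        forward=b.inverse,
        inverse=b.forward,
        log_det_jac=lambda y, x: -b.log_det_jac(x, y),
    )

exp_bij = SimpleBijector(jnp.exp, jnp.log, lambda x, y: x)
log_bij = reverse(exp_bij)

y = jnp.array(2.0)
x = log_bij.forward(y)          # log(2)
ld = log_bij.log_det_jac(y, x)  # log |d log(y)/dy| = -log(y)
```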

pangolin.jax_backend.bijectors.compose_jax_bijectors(bijectors, log_det_direction='forward')

Compose a sequence of JaxBijector instances into a single JaxBijector.

Parameters:
  • bijectors (Sequence[JaxBijector])

  • log_det_direction (str)

Return type:

JaxBijector
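
Composition rests on the chain rule: if T = T₂ ∘ T₁, then log |det ∇T(x)| = log |det ∇T₁(x)| + log |det ∇T₂(T₁(x))|, so the log determinants of the pieces add. The sketch checks this in plain JAX for T₁ = exp followed by a hypothetical affine map T₂(u) = 3u + 1. It makes no assumption about the order in which compose_jax_bijectors applies its inputs or about what log_det_direction controls.

```python
import jax
import jax.numpy as jnp

def t1(x):
    return jnp.exp(x)

def log_det_t1(x):
    return x  # log |d exp(x)/dx|

def t2(u):
    return 3.0 * u + 1.0

def log_det_t2(u):
    return jnp.log(3.0)  # constant slope, so the log determinant does not depend on u

def composed(x):
    # Apply t1 first, then t2
    return t2(t1(x))

x = jnp.array(-0.5)
u = t1(x)
# Chain rule: the log determinants of the pieces add.
log_det_sum = log_det_t1(x) + log_det_t2(u)
log_det_autodiff = jnp.log(jnp.abs(jax.grad(composed)(x)))
```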

pangolin.jax_backend.bijectors.exp()

Creates a JaxBijector instance that applies the exponential function.

Example

>>> import jax.numpy as jnp
>>> x = jnp.array(0.0)
>>> exp().forward(x)
Array(1., dtype=...)

pangolin.jax_backend.bijectors.log()

Creates a JaxBijector instance that applies the natural logarithm.

Example

>>> import jax.numpy as jnp
>>> x = jnp.array(1.0)
>>> log().forward(x)
Array(0., dtype=...)

pangolin.jax_backend.bijectors.logit()

Create a JaxBijector instance that applies the logit transformation y = logit(x). Commonly used to transform from (0, 1) to the reals.

Example

>>> import jax.numpy as jnp
>>> x = jnp.array(0.5)
>>> logit().forward(x)
Array(0., dtype=...)

pangolin.jax_backend.bijectors.inv_logit()

Create a JaxBijector instance that applies the inverse logit (expit/sigmoid).

Example

>>> import jax.numpy as jnp
>>> y = jnp.array(0.0)
>>> inv_logit().forward(y)
Array(0.5, dtype=...)

pangolin.jax_backend.bijectors.scaled_logit(a, b)

Create a JaxBijector instance that applies the scaled logit y = logit((x-a)/(b-a)). Commonly used to transform from (a, b) to the reals.
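
Writing u = (x-a)/(b-a), the inverse is x = a + (b-a)·sigmoid(y), and the chain rule gives log |dy/dx| = -log u - log(1-u) - log(b-a). The plain-JAX sketch below encodes these formulas with hypothetical helper names and example endpoints a = -2, b = 5 (not the library's code), and checks the log determinant against jax.grad.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logit

a, b = -2.0, 5.0  # example interval endpoints

def forward(x):
    return logit((x - a) / (b - a))

def inverse(y):
    return a + (b - a) * jax.nn.sigmoid(y)

def log_det_jac(x, y):
    # d/du logit(u) = 1 / (u (1 - u)), and du/dx = 1 / (b - a)
    u = (x - a) / (b - a)
    return -jnp.log(u) - jnp.log1p(-u) - jnp.log(b - a)

x = jnp.array(1.5)  # midpoint of (a, b), so y should be 0
y = forward(x)
autodiff = jnp.log(jnp.abs(jax.grad(forward)(x)))
```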

pangolin.jax_backend.bijectors.fill_tril()

Create a JaxBijector instance that fills a lower-triangular matrix from a vector. Used to transform real vectors into lower-triangular matrices.

Example

>>> import jax.numpy as jnp
>>> x = jnp.array([1., 2., 3.])
>>> fill_tril().forward(x)
Array([[1., 0.],
       [2., 3.]], dtype=...)
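
A lower-triangular n×n matrix has m = n(n+1)/2 free entries, so the vector length determines the matrix size via n = (√(8m+1) - 1)/2. The sketch below reimplements the idea in plain JAX with jnp.tril_indices under the hypothetical names fill_tril_sketch and extract_tril_sketch. Its row-by-row ordering is an assumption; it agrees with the 2×2 example above, but that example cannot distinguish row-wise from column-wise filling.

```python
import math

import jax.numpy as jnp

def fill_tril_sketch(v):
    m = v.shape[0]
    # Solve m = n (n + 1) / 2 for n
    n = (math.isqrt(8 * m + 1) - 1) // 2
    if n * (n + 1) // 2 != m:
        raise ValueError(f"length {m} is not a triangular number")
    rows, cols = jnp.tril_indices(n)  # lower-triangular positions, row by row
    return jnp.zeros((n, n), dtype=v.dtype).at[rows, cols].set(v)

def extract_tril_sketch(X):
    rows, cols = jnp.tril_indices(X.shape[0])
    return X[rows, cols]

v = jnp.arange(1.0, 7.0)  # 6 entries -> 3x3 matrix
L = fill_tril_sketch(v)
```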

pangolin.jax_backend.bijectors.extract_tril()

Create a JaxBijector instance that extracts the lower-triangular part of a matrix as a vector. Commonly used to transform lower-triangular matrices into real vectors.

Example

>>> import jax.numpy as jnp
>>> X = jnp.array([[1., 0.], [2., 3.]])
>>> extract_tril().forward(X)
Array([1., 2., 3.], dtype=...)

pangolin.jax_backend.bijectors.exp_diagonal()

Create a JaxBijector instance that exponentiates the diagonal of a matrix. Commonly used to transform real lower-triangular matrices into Cholesky factors.

Example

>>> import jax.numpy as jnp
>>> X = jnp.array([[0., 0.], [2., 0.]])
>>> exp_diagonal().forward(X)
Array([[1., 0.],
       [2., 1.]], dtype=...)
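
Because this transformation changes only the diagonal, its Jacobian is itself diagonal, with entries exp(Xᵢᵢ) at the diagonal positions and 1 elsewhere, so log |det ∇T(X)| is the sum of the diagonal entries of X (its trace). The plain-JAX sketch below (hypothetical exp_diagonal_sketch, not the library's implementation) confirms this with jax.jacfwd over all n² entries.

```python
import jax
import jax.numpy as jnp

def exp_diagonal_sketch(X):
    # Replace the diagonal with its exponential; leave the other entries unchanged.
    d = jnp.diag(X)
    return X - jnp.diag(d) + jnp.diag(jnp.exp(d))

X = jnp.array([[0.3, 0.0], [2.0, -1.2]])
n = X.shape[0]
# jacfwd returns an (n, n, n, n) array; flatten it to (n*n, n*n) before taking the determinant.
J = jax.jacfwd(exp_diagonal_sketch)(X).reshape(n * n, n * n)
_, logdet = jnp.linalg.slogdet(J)
```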

pangolin.jax_backend.bijectors.log_diagonal()

Create a JaxBijector instance that takes the logarithm of the diagonal of a matrix. Commonly used to transform Cholesky factors into real lower-triangular matrices.

Example

>>> import jax.numpy as jnp
>>> X = jnp.array([[1., 0.], [2., 1.]])
>>> log_diagonal().forward(X)
Array([[0., 0.],
       [2., 0.]], dtype=...)

pangolin.jax_backend.bijectors.cholesky()

Create a JaxBijector instance that applies a Cholesky decomposition. Commonly used to transform symmetric positive definite matrices into lower-triangular matrices with a positive diagonal (Cholesky factors).

Example

>>> import jax.numpy as jnp
>>> X = jnp.array([[1., 0.], [0., 1.]])
>>> cholesky().forward(X)
Array([[1., 0.],
       [0., 1.]], dtype=...)
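
For change-of-variables computations, the relevant Jacobian acts on the n(n+1)/2 free lower-triangular entries. A standard result is that the map L ↦ LLᵀ, restricted to those entries, has log |det| = n log 2 + sum over i = 1, …, n of (n - i + 1) log Lᵢᵢ, so the Cholesky direction (from the SPD matrix to L) has the negative of this value. The sketch checks the formula with jax.jacfwd; it does not claim to reflect how pangolin computes this term.

```python
import jax
import jax.numpy as jnp

n = 3
rows, cols = jnp.tril_indices(n)

def vec_to_tril(v):
    # Place the n(n+1)/2 free entries into a lower-triangular matrix
    return jnp.zeros((n, n)).at[rows, cols].set(v)

def chol_to_cov_vec(v):
    # L -> L L^T, keeping only the lower-triangular entries of the result
    L = vec_to_tril(v)
    return (L @ L.T)[rows, cols]

L = jnp.array([[1.5, 0.0, 0.0], [0.3, 0.8, 0.0], [-0.2, 0.4, 1.1]])
v = L[rows, cols]

J = jax.jacfwd(chol_to_cov_vec)(v)
sign, logdet_autodiff = jnp.linalg.slogdet(J)

exponents = n - jnp.arange(n)  # n, n-1, ..., 1
logdet_formula = n * jnp.log(2.0) + jnp.sum(exponents * jnp.log(jnp.diag(L)))

# log |det ∇T| for T: SPD matrix -> Cholesky factor
logdet_cholesky = -logdet_formula
```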

pangolin.jax_backend.bijectors.spd_to_unconstrained()

Create a JaxBijector instance that transforms a symmetric positive definite matrix into a vector of unconstrained reals. This is accomplished by (1) taking the Cholesky decomposition, (2) taking the logarithm of the diagonal, and (3) extracting the lower-triangular entries.

Example

>>> import jax.numpy as jnp
>>> # Identity matrix is symmetric positive definite
>>> X = jnp.array([[1., 0.], [0., 1.]])
>>> spd_to_unconstrained().forward(X)
Array([0., 0., 0.], dtype=...)
>>> # Transform back to SPD matrix
>>> unconstrained_vec = jnp.array([0., 2., 0.])
>>> spd_to_unconstrained().inverse(unconstrained_vec)
Array([[1., 2.],
       [2., 5.]], dtype=...)
>>> X = jnp.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
>>> Y = spd_to_unconstrained().forward(X)
>>> Y
Array([0., 0., 0., 0., 0., 0.], dtype=...)
>>> X_new = spd_to_unconstrained().inverse(Y)
>>> jnp.allclose(X, X_new)
Array(True, dtype=bool)

pangolin.jax_backend.bijectors.stick_breaking()

Create a JaxBijector instance that applies the stick-breaking transformation.
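
The docstring does not say which direction the forward transformation goes (softmax_centered, for comparison, is described as mapping from the simplex to unconstrained vectors). The plain-JAX sketch below implements both directions between R^(K-1) and the K-simplex, using Stan's convention of subtracting log(K - k) so that the zero vector maps to the uniform simplex. The function names are hypothetical, and pangolin's offsets and ordering may differ.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logit

def _offsets(K):
    # Stan-style offsets log(K - k) for k = 1, ..., K-1
    return jnp.log(K - 1.0 - jnp.arange(K - 1))

def unconstrained_to_simplex(y):
    K = y.shape[0] + 1
    # Fraction of the remaining stick broken off at each step
    z = jax.nn.sigmoid(y - _offsets(K))
    # Stick length remaining before each break: 1, (1 - z1), (1 - z1)(1 - z2), ...
    remaining = jnp.concatenate([jnp.ones(1), jnp.cumprod(1.0 - z)])
    # The last component takes whatever stick is left over
    return jnp.concatenate([z, jnp.ones(1)]) * remaining

def simplex_to_unconstrained(x):
    K = x.shape[0]
    # Stick length remaining before each break: 1 minus the pieces already taken
    remaining = 1.0 - jnp.concatenate([jnp.zeros(1), jnp.cumsum(x[:-2])])
    z = x[:-1] / remaining
    return logit(z) + _offsets(K)

y = jnp.array([0.3, -1.2, 2.0])
x = unconstrained_to_simplex(y)  # a point on the 4-simplex
```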

pangolin.jax_backend.bijectors.softmax_centered()

Create a JaxBijector instance for the sum-to-zero softmax mapping. Commonly used to transform from the unit simplex to unconstrained vectors.
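
One way to realize a sum-to-zero softmax mapping is sketched below in plain JAX, under the assumption that the unconstrained side has K-1 entries: append the negative sum to obtain a K-vector that sums to zero, then apply softmax. The inverse is the centered log, log(x) minus its mean, with the redundant last entry dropped. The names are hypothetical, and pangolin's parameterization may differ.

```python
import jax
import jax.numpy as jnp

def centered_to_simplex(y):
    # Append -sum(y) so the K-vector sums to zero, then apply softmax.
    z = jnp.concatenate([y, -jnp.sum(y, keepdims=True)])
    return jax.nn.softmax(z)

def simplex_to_centered(x):
    # Centered log: log(x) minus its mean sums to zero, so the last entry is redundant.
    z = jnp.log(x) - jnp.mean(jnp.log(x))
    return z[:-1]

y = jnp.array([0.5, -1.0, 2.0])
x = centered_to_simplex(y)  # a point on the 4-simplex
```

The round trip works because log(softmax(z)) differs from z only by a constant, which the centering removes, and z already has mean zero.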

pangolin.jax_backend.bijectors.softmax_centered_ilr()

Create a JaxBijector for the Isometric Log-Ratio (ILR) simplex mapping borrowed from Stan.

Uses the orthonormal Helmert basis.
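
With an orthonormal basis V (a K×(K-1) matrix whose columns are orthonormal and orthogonal to the all-ones vector), the ILR map sends a simplex point x to Vᵀ log x, and its inverse sends y to softmax(Vy). The plain-JAX sketch below builds one common form of the Helmert basis. The column order and signs are assumptions and may not match Stan's or pangolin's exact basis, and the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp

def helmert_basis(K):
    # Columns are orthonormal and orthogonal to the all-ones vector.
    V = jnp.zeros((K, K - 1))
    for j in range(1, K):
        c = 1.0 / jnp.sqrt(j * (j + 1.0))
        V = V.at[:j, j - 1].set(c)      # first j entries of column j
        V = V.at[j, j - 1].set(-j * c)  # balancing entry so the column sums to zero
    return V

def simplex_to_ilr(x):
    V = helmert_basis(x.shape[0])
    # V.T annihilates constant vectors, so centering log(x) first is unnecessary.
    return V.T @ jnp.log(x)

def ilr_to_simplex(y):
    V = helmert_basis(y.shape[0] + 1)
    return jax.nn.softmax(V @ y)

y = jnp.array([0.2, -0.7, 1.1])
x = ilr_to_simplex(y)  # a point on the 4-simplex
```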

pangolin.jax_backend.bijectors.default_bijector_dict

A reasonable default bijector dictionary. Its keys are random Op classes from pangolin.ir, and its values are either functions (lambdas) that create JaxBijector instances or None:

  • Beta, Dirichlet, Exponential, Gamma, Lognormal, Uniform, Wishart: function

  • Cauchy, MultiNormal, Normal, NormalPrec, StudentT: None

It is easy to provide alternative bijectors: create a new dictionary of the same form, mapping Op classes to functions that create JaxBijector instances. (See the module source for how the default is defined.)