jax_backend.bijectors
This should be a dictionary mapping primitive random Op classes to JaxBijector instances. This is commonly used to transform constrained random Ops (like Dirichlet or Uniform) to an unconstrained space to make gradient-based inference easier. This dictionary should cover all the base random Ops, excluding Composite, VMap and Scan.
- class pangolin.jax_backend.bijectors.JaxBijector(forward, inverse, log_det_jac)
The idea is that if P(X) is some density and Y=T(X) for a diffeomorphism T, then
P(Y=y) = P(X=T⁻¹(y)) × |det ∇T⁻¹(y)|
- Parameters:
forward – jax function implementing forward transformation given x
inverse – jax function implementing inverse transformation given y
log_det_jac – jax function implementing the log determinant of the Jacobian of the forward transformation given both x and y (may use either as convenient)
- log_det_jac(x, y)
Computes log |det ∇T(x)| = -log |det ∇T⁻¹(y)|. May use either x or y as convenient.
- inverse_and_log_det_jac(y)
Computes T⁻¹(y) and log |det ∇T(x)|. (Flip the sign of the second return value if you want the log determinant of the Jacobian of the inverse transformation.)
- property reverse
Get a JaxBijector for T⁻¹.
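The change-of-variables identity behind JaxBijector can be checked numerically. The following is a minimal sketch in plain JAX, not pangolin's implementation: `SketchBijector`, `exp_sketch` and `transformed_logpdf` are hypothetical names introduced only for illustration, and the container merely mirrors the three documented constructor arguments.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class SketchBijector(NamedTuple):
    """Hypothetical stand-in for JaxBijector (not pangolin's class)."""
    forward: Callable      # x -> y = T(x)
    inverse: Callable      # y -> x = T^{-1}(y)
    log_det_jac: Callable  # (x, y) -> log |det grad T(x)|


# Y = exp(X): log |dT/dx| = log(exp(x)) = x (equivalently, log(y)).
exp_sketch = SketchBijector(
    forward=jnp.exp,
    inverse=jnp.log,
    log_det_jac=lambda x, y: x,
)


def transformed_logpdf(base_logpdf, bij, y):
    """log P(Y=y) = log P(X=T^{-1}(y)) - log |det grad T(x)|."""
    x = bij.inverse(y)
    return base_logpdf(x) - bij.log_det_jac(x, y)


y = jnp.array(2.0)
log_p_y = transformed_logpdf(norm.logpdf, exp_sketch, y)

# If X ~ Normal(0, 1), then Y = exp(X) is standard log-normal, whose
# log density is norm.logpdf(log y) - log y.
expected = norm.logpdf(jnp.log(y)) - jnp.log(y)

# Cross-check the hand-written log-determinant against autodiff.
x = exp_sketch.inverse(y)
autodiff_ldj = jnp.log(jnp.abs(jax.grad(exp_sketch.forward)(x)))
```

Subtracting the forward log-determinant is the same as adding log |det ∇T⁻¹(y)|, which is the sign relation stated for log_det_jac.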
- pangolin.jax_backend.bijectors.compose_jax_bijectors(bijectors, log_det_direction='forward')
- Parameters:
bijectors (Sequence[JaxBijector])
log_det_direction (str)
- Return type:
- pangolin.jax_backend.bijectors.exp()
Creates a JaxBijector instance that applies the exponential function.
Example
>>> import jax.numpy as jnp
>>> x = jnp.array(0.0)
>>> exp().forward(x)
Array(1., dtype=...)
- pangolin.jax_backend.bijectors.log()
Creates a JaxBijector instance that applies the natural logarithm.
Example
>>> import jax.numpy as jnp
>>> x = jnp.array(1.0)
>>> log().forward(x)
Array(0., dtype=...)
- pangolin.jax_backend.bijectors.logit()
Create a JaxBijector instance that applies the logit bijector y = logit(x). Commonly used to transform from [0,1] to reals.
Example
>>> import jax.numpy as jnp
>>> x = jnp.array(0.5)
>>> logit().forward(x)
Array(0., dtype=...)
- pangolin.jax_backend.bijectors.inv_logit()
Create a JaxBijector instance that applies the inverse logit (expit/sigmoid).
Example
>>> import jax.numpy as jnp
>>> y = jnp.array(0.0)
>>> inv_logit().forward(y)
Array(0.5, dtype=...)
- pangolin.jax_backend.bijectors.scaled_logit(a, b)
Create a JaxBijector instance that applies the scaled logit y = logit((x-a)/(b-a)). Commonly used to transform from [a,b] to reals.
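This entry has no doctest, so here is a small sketch of the formula y = logit((x-a)/(b-a)) written directly in JAX. The function names are hypothetical and independent of pangolin; the sketch assumes a < x < b and uses autodiff only to confirm the hand-derived log-determinant.

```python
import jax
import jax.numpy as jnp


def scaled_logit_forward(x, a, b):
    # y = logit((x - a) / (b - a)) = log(x - a) - log(b - x)
    return jnp.log(x - a) - jnp.log(b - x)


def scaled_logit_inverse(y, a, b):
    # x = a + (b - a) * sigmoid(y)
    return a + (b - a) * jax.nn.sigmoid(y)


def scaled_logit_log_det_jac(x, a, b):
    # dy/dx = (b - a) / ((x - a) * (b - x))
    return jnp.log(b - a) - jnp.log(x - a) - jnp.log(b - x)


a, b = 2.0, 6.0
x = jnp.array(3.0)

y = scaled_logit_forward(x, a, b)         # logit(0.25)
x_back = scaled_logit_inverse(y, a, b)    # recovers x
ldj = scaled_logit_log_det_jac(x, a, b)

# Autodiff cross-check of the log-determinant (scalar case).
ldj_autodiff = jnp.log(jnp.abs(jax.grad(scaled_logit_forward)(x, a, b)))
```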
- pangolin.jax_backend.bijectors.fill_tril()
A JaxBijector instance that fills a lower-triangular matrix from a vector. Used to transform from real vectors to lower-triangular matrices.
Example
>>> import jax.numpy as jnp
>>> x = jnp.array([1., 2., 3.])
>>> fill_tril().forward(x)
Array([[1., 0.],
       [2., 3.]], dtype=...)
- pangolin.jax_backend.bijectors.extract_tril()
A JaxBijector instance that extracts the lower-triangular part of a matrix. Commonly used to transform from lower-triangular matrices to real vectors.
Example
>>> import jax.numpy as jnp
>>> X = jnp.array([[1., 0.], [2., 3.]])
>>> extract_tril().forward(X)
Array([1., 2., 3.], dtype=...)
- pangolin.jax_backend.bijectors.exp_diagonal()
Create a JaxBijector instance that exponentiates the diagonal of a matrix. Commonly used to transform real lower-triangular matrices into Cholesky factors.
Example
>>> import jax.numpy as jnp
>>> X = jnp.array([[0., 0.], [2., 0.]])
>>> exp_diagonal().forward(X)
Array([[1., 0.],
       [2., 1.]], dtype=...)
- pangolin.jax_backend.bijectors.log_diagonal()
Create a JaxBijector instance that takes the logarithm of the diagonal of a matrix. Commonly used to transform Cholesky factors into real lower-triangular matrices.
Example
>>> import jax.numpy as jnp
>>> X = jnp.array([[1., 0.], [2., 1.]])
>>> log_diagonal().forward(X)
Array([[0., 0.],
       [2., 0.]], dtype=...)
- pangolin.jax_backend.bijectors.cholesky()
Create a JaxBijector instance that applies a Cholesky decomposition. Commonly used to transform from symmetric positive definite matrices into triangular matrices.
Example
>>> import jax.numpy as jnp
>>> X = jnp.array([[1., 0.], [0., 1.]])
>>> cholesky().forward(X)
Array([[1., 0.],
       [0., 1.]], dtype=...)
- pangolin.jax_backend.bijectors.spd_to_unconstrained()
Returns a JaxBijector instance that transforms a symmetric positive definite matrix into the space of unconstrained reals. Accomplished by (1) taking a Cholesky decomposition, (2) taking the logarithm of the diagonal, and (3) extracting the lower-triangular entries.
Example
>>> import jax.numpy as jnp
>>> # Identity matrix is symmetric positive definite
>>> X = jnp.array([[1., 0.], [0., 1.]])
>>> spd_to_unconstrained().forward(X)
Array([0., 0., 0.], dtype=...)
>>> # Transform back to SPD matrix
>>> unconstrained_vec = jnp.array([0., 2., 0.])
>>> spd_to_unconstrained().inverse(unconstrained_vec)
Array([[1., 2.],
       [2., 5.]], dtype=float32)
>>> X = jnp.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
>>> Y = spd_to_unconstrained().forward(X)
>>> Y
Array([0., 0., 0., 0., 0., 0.], dtype=...)
>>> X_new = spd_to_unconstrained().inverse(Y)
>>> jnp.allclose(X, X_new)
Array(True, dtype=bool)
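The three steps named above (Cholesky, log of the diagonal, extract the lower triangle) can be sketched in plain JAX as follows. This is an illustration under assumptions, not pangolin's code: `spd_to_vec` and `vec_to_spd` are hypothetical helpers, and the row-major ordering of the lower-triangular entries is assumed from the fill_tril example.

```python
import jax.numpy as jnp


def spd_to_vec(X):
    """SPD matrix -> unconstrained vector of length n(n+1)/2."""
    n = X.shape[0]
    diag = jnp.diag_indices(n)
    L = jnp.linalg.cholesky(X)              # (1) lower Cholesky factor
    L = L.at[diag].set(jnp.log(L[diag]))    # (2) log of the diagonal
    return L[jnp.tril_indices(n)]           # (3) lower-triangular entries


def vec_to_spd(v, n):
    """Unconstrained vector of length n(n+1)/2 -> n x n SPD matrix."""
    diag = jnp.diag_indices(n)
    # Undo (3): place entries row by row into the lower triangle.
    L = jnp.zeros((n, n), dtype=v.dtype).at[jnp.tril_indices(n)].set(v)
    # Undo (2): exponentiate the diagonal so it is strictly positive.
    L = L.at[diag].set(jnp.exp(L[diag]))
    # Undo (1): rebuild the SPD matrix from its Cholesky factor.
    return L @ L.T


X = jnp.array([[4.0, 2.0], [2.0, 5.0]])
v = spd_to_vec(X)           # [log 2, 1, log 2]
X_back = vec_to_spd(v, 2)   # recovers X

# Same input as the inverse example above: gives [[1, 2], [2, 5]].
X_doc = vec_to_spd(jnp.array([0.0, 2.0, 0.0]), 2)
```

Because the diagonal is exponentiated on the way back, any real vector of the right length maps to a valid SPD matrix, which is what makes the unconstrained space convenient for gradient-based inference.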
- pangolin.jax_backend.bijectors.stick_breaking()
Create a JaxBijector instance that applies the stick-breaking transformation.
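For readers unfamiliar with stick-breaking, here is a sketch of one common variant in plain JAX. It is an assumption for illustration only: the exact convention used by stick_breaking() (direction of forward, any offsets) is not stated here, and the function names below are hypothetical.

```python
import jax
import jax.numpy as jnp


def stick_breaking_to_simplex(z):
    """Map z in R^(K-1) to a point on the K-simplex (one common variant)."""
    v = jax.nn.sigmoid(z)  # fraction of the remaining stick broken off at each step
    # Stick left before step k: 1, (1-v1), (1-v1)(1-v2), ...
    remaining = jnp.concatenate([jnp.ones(1), jnp.cumprod(1.0 - v)])
    # First K-1 pieces are v_k * remaining_k; the last piece is whatever is left.
    return jnp.concatenate([v, jnp.ones(1)]) * remaining


def simplex_to_stick_breaking(x):
    """Inverse map: K-simplex -> R^(K-1)."""
    # Stick left before each of the first K-1 steps: 1 - (sum of earlier pieces).
    remaining = 1.0 - jnp.concatenate([jnp.zeros(1), jnp.cumsum(x[:-2])])
    v = x[:-1] / remaining
    return jnp.log(v) - jnp.log1p(-v)  # logit(v)


z = jnp.array([0.3, -1.2, 2.0])
x = stick_breaking_to_simplex(z)      # 4 nonnegative entries summing to 1
z_back = simplex_to_stick_breaking(x)

x_zero = stick_breaking_to_simplex(jnp.zeros(2))  # [0.5, 0.25, 0.25]
```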
- pangolin.jax_backend.bijectors.softmax_centered()
Create a JaxBijector instance for the sum-to-zero softmax mapping. Commonly used to transform from the unit simplex to unconstrained vectors.
- pangolin.jax_backend.bijectors.softmax_centered_ilr()
Create a JaxBijector for the Isometric Log-Ratio (ILR) simplex mapping borrowed from Stan. Uses the orthonormal Helmert basis.
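As a rough illustration of an ILR mapping built on a Helmert-style basis, the sketch below constructs a K x (K-1) matrix V with orthonormal, sum-to-zero columns and maps between the simplex and R^(K-1). The helper names are hypothetical, and the sign and ordering conventions are assumptions that may differ from both Stan and pangolin.

```python
import jax
import jax.numpy as jnp


def helmert_basis(K):
    """K x (K-1) matrix whose columns are orthonormal and each sum to zero."""
    cols = []
    for j in range(1, K):
        # j ones, then -j, then zeros: sums to zero, squared norm j * (j + 1).
        col = jnp.concatenate(
            [jnp.ones(j), jnp.array([-float(j)]), jnp.zeros(K - j - 1)]
        )
        cols.append(col / jnp.sqrt(j * (j + 1.0)))
    return jnp.stack(cols, axis=1)


def ilr_to_simplex(y):
    """R^(K-1) -> K-simplex: softmax of the sum-to-zero vector V @ y."""
    V = helmert_basis(y.shape[0] + 1)
    return jax.nn.softmax(V @ y)


def simplex_to_ilr(x):
    """K-simplex -> R^(K-1): coordinates of log(x) in the basis V."""
    V = helmert_basis(x.shape[0])
    return V.T @ jnp.log(x)


V = helmert_basis(4)
y = jnp.array([0.5, -1.0, 2.0])
x = ilr_to_simplex(y)        # point on the 4-simplex
y_back = simplex_to_ilr(x)   # recovers y
```

The round trip works because the columns of V are orthogonal to the all-ones vector, so the additive constant introduced by the softmax normalization is projected away.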
- pangolin.jax_backend.bijectors.default_bijector_dict = {<class 'pangolin.ir.Beta'>: <function <lambda>>, <class 'pangolin.ir.Cauchy'>: None, <class 'pangolin.ir.Dirichlet'>: <function <lambda>>, <class 'pangolin.ir.Exponential'>: <function <lambda>>, <class 'pangolin.ir.Gamma'>: <function <lambda>>, <class 'pangolin.ir.Lognormal'>: <function <lambda>>, <class 'pangolin.ir.MultiNormal'>: None, <class 'pangolin.ir.Normal'>: None, <class 'pangolin.ir.NormalPrec'>: None, <class 'pangolin.ir.StudentT'>: None, <class 'pangolin.ir.Uniform'>: <function <lambda>>, <class 'pangolin.ir.Wishart'>: <function <lambda>>}
A reasonable default bijector dictionary:
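Op class      bijector factory
Beta          lambda (view source)
Cauchy        None
Dirichlet     lambda (view source)
Exponential   lambda (view source)
Gamma         lambda (view source)
Lognormal     lambda (view source)
MultiNormal   None
Normal        None
NormalPrec    None
StudentT      None
Uniform       lambda (view source)
Wishart       lambda (view source)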
It is easy to provide alternative bijectors: just create a new dictionary, with functions that create new JaxBijector instances (if you want). (View source to see how this is defined.)