util Module

This submodule contains internal utility functions. None of these functions import or use any other parts of Pangolin. End-users of Pangolin are not typically expected to use these functions directly.

pangolin.util.comma_separated(stuff, fun=None, parens=True, spaces=False)

Convenience function for turning sequences into comma-separated strings.

Parameters:
  • stuff (Sequence) – list to print

  • fun (Callable | None) – function to apply to each item before printing

  • parens (bool) – do you want parentheses?

  • spaces (bool) – do you want spaces between items?

Returns:

string containing all items with commas between them

Return type:

str

Examples

>>> comma_separated(['a', 'b', 'c'])
'(a,b,c)'
>>> comma_separated(['a', 'b', 'c'], lambda s: s + "0")
'(a0,b0,c0)'
>>> comma_separated(['a', 'b', 'c'], parens=False)
'a,b,c'
>>> comma_separated(['a', 'b', 'c'], lambda s: s + "0", parens=False)
'a0,b0,c0'
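
The examples above do not exercise the spaces flag. The following is a minimal sketch of how such a function could be written, not Pangolin's actual source: the name comma_separated_sketch is hypothetical, and both the str() conversion and the exact spacing produced by spaces=True (one space after each comma) are assumptions.

```python
from typing import Callable, Optional, Sequence


def comma_separated_sketch(
    stuff: Sequence,
    fun: Optional[Callable] = None,
    parens: bool = True,
    spaces: bool = False,
) -> str:
    """Hypothetical stand-in for comma_separated (not Pangolin's code)."""
    # Apply fun to each item if one was given, then convert to str.
    items = [str(item if fun is None else fun(item)) for item in stuff]
    # Assumption: spaces=True means a single space after every comma.
    separator = ", " if spaces else ","
    joined = separator.join(items)
    # Optionally wrap the whole result in parentheses.
    return f"({joined})" if parens else joined


print(comma_separated_sketch(["a", "b", "c"]))                # (a,b,c)
print(comma_separated_sketch(["a", "b", "c"], spaces=True))   # (a, b, c)
```

With spaces=False the sketch reproduces the four documented outputs above.
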
class pangolin.util.VarNames

Convenience class to automatically give unique string names to objects.

This class has no parameters.

Examples

>>> var_names = VarNames()
>>> var_names['bob']
'v0v'
>>> var_names['alice']
'v1v'
>>> var_names['bob']
'v0v'
>>> var_names['carlos']
'v2v'
>>> var_names['alice']
'v1v'
>>> var_names['bob']
'v0v'
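
The behavior shown above (a stable name per object, numbered in order of first access) can be sketched with a plain dictionary. This is an illustrative guess at the mechanism, not the library's implementation; the class name VarNamesSketch is hypothetical, and it assumes the objects being named are hashable.

```python
class VarNamesSketch:
    """Hypothetical stand-in for VarNames (not Pangolin's code)."""

    def __init__(self):
        # Maps each object seen so far to its generated name.
        self._names = {}

    def __getitem__(self, obj):
        # First access: mint a new name from the running count.
        if obj not in self._names:
            self._names[obj] = f"v{len(self._names)}v"
        # Later accesses return the same name again.
        return self._names[obj]


var_names = VarNamesSketch()
print(var_names["bob"], var_names["alice"], var_names["bob"])  # v0v v1v v0v
```
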
pangolin.util.num_not_none(*args)

Count how many arguments are not None.

Parameters:

args (Any) – any number of arguments of any type

Returns:

number of items that are non-None

Return type:

int

Examples

>>> num_not_none(1)
1
>>> num_not_none(1,2)
2
>>> num_not_none(1,None)
1
>>> num_not_none(None,1,None)
1
>>> num_not_none(None,"cat","dog",77,None)
3

pangolin.util.all_unique(a)

Are all items in a sequence unique?

Parameters:

a (Sequence) – some sequence

Returns:

True if all items are unique, False otherwise.

Return type:

bool
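
No example is given for this function, so here is a small sketch of the check it describes. The name all_unique_sketch is hypothetical and the set-based approach is an assumption (it requires hashable items); the real function may be implemented differently.

```python
def all_unique_sketch(a):
    """Hypothetical stand-in for all_unique (not Pangolin's code)."""
    # A set drops duplicates, so the lengths agree only if every
    # item in the sequence is distinct. Assumes hashable items.
    return len(set(a)) == len(a)


print(all_unique_sketch([1, 2, 3]))  # True
print(all_unique_sketch([1, 2, 1]))  # False
```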

pangolin.util.intersects(A, B)

Check if two collections have any common elements.

Parameters:
  • A (Iterable) – First collection of elements.

  • B (Iterable) – Second collection of elements.

Returns:

True if A and B share at least one element, False otherwise.

Return type:

bool

Examples

>>> intersects([1, 2, 3], [3, 4, 5])
True
>>> intersects(['apple', 'banana'], ['orange', 'grape'])
False

class pangolin.util.WriteOnceDict

A dict where you can’t overwrite entries. Takes no parameters.

Examples

>>> d = WriteOnceDict()
>>> d['a'] = 2
>>> d['a']
2
>>> d['b'] = 3
>>> d['b']
3
>>> d['a'] = 1
Traceback (most recent call last):
ValueError: Cannot overwrite existing key a in WriteOnceDict
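
One way to get this behavior is to subclass dict and guard __setitem__. The sketch below is an assumption about the mechanism rather than the library's code, and WriteOnceDictSketch is a hypothetical name. Note that in CPython, dict.update and dict.setdefault bypass an overridden __setitem__, so a full implementation would likely need to guard those as well.

```python
class WriteOnceDictSketch(dict):
    """Hypothetical stand-in for WriteOnceDict (not Pangolin's code)."""

    def __setitem__(self, key, value):
        # Refuse to replace a value that is already stored.
        if key in self:
            raise ValueError(
                f"Cannot overwrite existing key {key} in WriteOnceDict"
            )
        super().__setitem__(key, value)


d = WriteOnceDictSketch()
d["a"] = 2
try:
    d["a"] = 1
except ValueError as e:
    print(e)  # Cannot overwrite existing key a in WriteOnceDict
```
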
class pangolin.util.WriteOnceDefaultDict(default_factory)

A dict where you can’t overwrite entries. If you access a key that has no value, default_factory is called with that key to create a key/value pair automatically. Once you have written to or accessed a key, it cannot be changed.

Parameters:

default_factory (Callable) – function giving values for non-existent entries

Examples

>>> d = WriteOnceDefaultDict(lambda s : len(s))
>>> d['cat'] = 5
>>> d['bob']
3
>>> d['cat']
5
>>> d['bob'] = 5 # raises an error (as expected)
Traceback (most recent call last):
ValueError: Cannot overwrite existing key bob in WriteOnceDefaultDict
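
Unlike collections.defaultdict, the factory in the example above receives the missing key. A minimal sketch of that behavior, using dict's __missing__ hook, is shown here; the name WriteOnceDefaultDictSketch is hypothetical and the approach is an assumption, not the library's actual code.

```python
class WriteOnceDefaultDictSketch(dict):
    """Hypothetical stand-in for WriteOnceDefaultDict (not Pangolin's code)."""

    def __init__(self, default_factory):
        super().__init__()
        self.default_factory = default_factory

    def __setitem__(self, key, value):
        # Keys that were written or auto-created are frozen.
        if key in self:
            raise ValueError(
                f"Cannot overwrite existing key {key} in WriteOnceDefaultDict"
            )
        super().__setitem__(key, value)

    def __missing__(self, key):
        # Called by dict.__getitem__ when key is absent: build the
        # value from the key itself and store it permanently.
        value = self.default_factory(key)
        super().__setitem__(key, value)
        return value


d = WriteOnceDefaultDictSketch(len)
d["cat"] = 5
print(d["bob"], d["cat"])  # 3 5
```
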
pangolin.util.tree_map_with_none_as_leaf(f, tree, *rest)

Call jax.tree_util.tree_map using a special is_leaf function that preserves None. This is exactly like jax.tree.map except that None nodes are treated as leaves rather than as nonexistent.

Parameters:
  • f (Callable) – function that takes 1 + len(rest) arguments, to be applied to corresponding leaves

  • tree (PyTree) – a pytree to be mapped over, with each leaf providing the first positional argument to f

  • rest (PyTree) – additional pytrees to be mapped over, each of which must either have the same structure as tree or have tree as a prefix.

Examples

>>> f = lambda x: 0 if x is None else len(x)
>>> tree = ["cat", ("bear", None)]
>>> jax.tree_util.tree_map(f, tree)
[3, (4, None)]
>>> tree_map_with_none_as_leaf(f, tree)
[3, (4, 0)]
>>> f2 = lambda leaf, seq: f(leaf) + len(seq)
>>> tree1 = ["cat", ("bear", None)]
>>> tree2 = [[1,2,3,4,5], ([], [1,2,3,4,5])]
>>> tree_map_with_none_as_leaf(f2, tree1, tree2)
[8, (4, 5)]

pangolin.util.tree_map_preserve_none(f, tree, *rest)
Parameters:
  • f (Callable)

  • tree (PyTree)

  • rest (PyTree)

pangolin.util.tree_structure_with_none_as_lead(pytree)
Parameters:

pytree (PyTree)

pangolin.util.tree_flatten_with_none_as_leaf(x)
Parameters:

x (PyTree)

pangolin.util.same(x, y)

Assert that x and y are either:

  • both None

  • equal floats

  • equal arrays (with equal shapes)

pangolin.util.same_tree(x, y, is_leaf=None)

Check that x and y have the same tree structure (including None) and that all leaves are equal. Arrays compare equal regardless of whether they come from regular numpy or jax numpy, and element types are ignored (e.g. numpy.array([2,3]) is considered equal to jax.numpy.array([2.0,3.0])).

Parameters:
  • x (PyTree)

  • y (PyTree)

  • is_leaf (Callable | None)

pangolin.util.map_inside_tree(f, tree)

Map a function over the leading axis of all leaf nodes inside a tree. If a None appears anywhere in the input tree, it is passed to f unchanged on every call. If f returns None anywhere in its output, that None appears unchanged in the result.

Examples

>>> def f(t):
...     a, (b, c) = t
...     return (a + b, a * b), c
>>> tree = np.array([1, 2]), (np.array([3, 4]), np.array([5, 6]))
>>> map_inside_tree(f, tree)
((array([4, 6]), array([3, 8])), array([5, 6]))
>>> def f(t):
...     a, (b, c) = t
...     return (a + b, c), None
>>> tree = np.array([1, 2]), (np.array([3, 4]), None)
>>> map_inside_tree(f, tree)
((array([4, 6]), None), None)
Parameters:
  • f (Callable)

  • tree (PyTree)

pangolin.util.assert_all_leaves_instance_of(tree, type, is_leaf=None)
Parameters:
  • tree (PyTree)

  • type (type)

  • is_leaf (Callable | None)

pangolin.util.assert_all_leaves_instance_of_with_none(tree, type)
Parameters:
  • tree (PyTree)

  • type (type)

pangolin.util.assert_is_sequence_of(seq, type)
Parameters:
  • seq (Sequence)

  • type (type)

pangolin.util.num2str(id)
Parameters:

id (int)

Return type:

str

pangolin.util.is_shape_tuple(a)

pangolin.util.tree_map_recurse_at_leaf(f, tree, *remaining_trees, is_leaf=None)

Applies a function f to corresponding leaves of tree and *remaining_trees.

This function implements a “recursive broadcast” behavior. If tree has a leaf at a path where any of remaining_trees has a subtree, that tree leaf is “broadcast” to all leaves of the corresponding subtree in remaining_trees.

Parameters:
  • f (Callable[[...], Any]) – The function to apply to the leaves. Its first argument will be a leaf from tree, and subsequent arguments will be corresponding leaves from remaining_trees. If broadcasting occurs, the first argument (leaf_from_tree) will be fixed for all leaves within the broadcasted subtree.

  • tree (PyTree) – The primary PyTree. Its leaves will trigger the broadcasting behavior. It is expected to be a “prefix” or “smaller” structure compared to remaining_trees at corresponding paths.

  • remaining_trees (PyTree) – One or more additional PyTrees to map over. These are expected to be “superset” structures relative to tree at corresponding paths.

  • is_leaf (Callable[[Any], bool] | None) – An optional callable that takes a single argument (a node in a PyTree) and returns True if that node should be considered a leaf (i.e., tree_map should not recurse into it), and False otherwise. This function is applied to both the outer and inner jax.tree_map calls. If None, JAX’s default leaf detection is used.

Returns:

A new PyTree holding the results of applying f. Its structure will match that of the first PyTree in remaining_trees (or tree if remaining_trees is empty).

Return type:

PyTree

Notes

  • If tree has a subtree at a path where one of remaining_trees has a leaf, jax.tree_map will raise a ValueError due to a structural mismatch.

  • This function leverages nested jax.tree_map calls for conciseness.

Examples

>>> # No broadcasting (standard tree_map behavior)
>>> tree1 = {'c': 5, 'd': 6}
>>> tree2 = {'c': 2, 'd': 3}
>>> tree_map_recurse_at_leaf(lambda l1, l2: l1 * l2, tree1, tree2)
{'c': 10, 'd': 18}
>>> # Simple broadcasting (multiply first leaf by second)
>>> tree1 = {'a': 10, 'b': 20}
>>> tree2 = {'a': {'x': 1, 'y': 2}, 'b': 3}
>>> tree_map_recurse_at_leaf(lambda l1, l2: l1 * l2, tree1, tree2)
{'a': {'x': 10, 'y': 20}, 'b': 60}
>>> # Custom is_leaf (treating None as a leaf)
>>> tree1 = {'data': 100, 'config': None}
>>> tree2 = {'data': {'val': 1, 'factor': 2}, 'config': 'default'}
>>> tree_map_recurse_at_leaf(
...     lambda l1, l2: f"{l1}_{l2}" if l1 is None else l1 * l2,
...     tree1, tree2, is_leaf=lambda x: x is None
... )
{'config': 'None_default', 'data': {'factor': 200, 'val': 100}}

pangolin.util.tree_map_recurse_at_leaf_with_none_as_leaf(f, tree, *remaining_trees)

Examples

>>> pytree1 = (0,[1,2])
>>> pytree2 = ("dog", ["cat", None])
>>> tree_map_recurse_at_leaf_with_none_as_leaf(lambda a,b: a, pytree1, pytree2)
(0, [1, 2])
>>> pytree1 = 3
>>> pytree2 = {"cat": 0, "dog": 2}
>>> tree_map_recurse_at_leaf_with_none_as_leaf(lambda a,b: a, pytree1, pytree2)
{'cat': 3, 'dog': 3}
Parameters:
  • f (Callable)

  • tree (PyTree)

  • remaining_trees (PyTree)

pangolin.util.flatten_fun(f, *args, is_leaf=None)

Get a new function that takes a single input (which is a list).

Parameters:
  • f (Callable)

  • args (Any)

  • is_leaf (Callable[[Any], bool] | None)

pangolin.util.dual_flatten(pytree1, pytree2)
Parameters:
  • pytree1 (PyTree) – first pytree

  • pytree2 (PyTree) – second pytree, must have same structure as pytree1 or at least have the structure of pytree1 as a prefix

Return type:

PyTree

Examples

>>> pytree1 = (0,[1,2])
>>> pytree2 = ("dog",["cat", "owl"])
>>> dual_flatten(pytree1, pytree2)
([0, 1, 2], ['dog', 'cat', 'owl'])
>>> pytree1 = (None,[1,2])
>>> pytree2 = ("dog", ["cat", None])
>>> dual_flatten(pytree1, pytree2)
([None, 1, 2], ['dog', 'cat', None])
>>> pytree1 = (None, 3)
>>> pytree2 = ("dog", ["cat", None])
>>> dual_flatten(pytree1, pytree2)
([None, 3, 3], ['dog', 'cat', None])

pangolin.util.short_pytree_string(treedef)

Get a string for a JAX PyTreeDef without printing PyTreeDef and scaring the noobs

pangolin.util.assimilate_vals(vars, vals)

Convert vals to a pytree of arrays with the same shape as vars. This is useful when users provide lists or tuples that should be automatically cast to a pytree of arrays. (Without vars, it would be impossible to tell a list of equal-length arrays from one larger array with an extra dimension.)

Parameters:
  • vars (PyTree)

  • vals (PyTree)

Return type:

PyTree[jnp.ndarray]

pangolin.util.flatten_args(vars, given_vars=None, given_vals=None)

pangolin.util.nth_index(lst, item, n)

pangolin.util.first_index(lst, condition)

pangolin.util.swapped_list(lst, i, j)

pangolin.util.first(lst, cond, default=None)

Get the first element of lst that satisfies cond, or default if no element does.
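
The described lookup can be sketched with next() over a generator expression. The name first_sketch is hypothetical and this is only one plausible way to write it, not necessarily how the library does.

```python
def first_sketch(lst, cond, default=None):
    """Hypothetical stand-in for first (not Pangolin's code)."""
    # next() returns the first matching item, or default when the
    # generator is exhausted without a match.
    return next((item for item in lst if cond(item)), default)


print(first_sketch([1, 4, 6, 7], lambda x: x % 2 == 0))      # 4
print(first_sketch([1, 3, 5], lambda x: x % 2 == 0, -1))     # -1
```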

pangolin.util.reverse_dict(d)
Parameters:

d (dict)

pangolin.util.replace_in_sequence(seq, i, new)

pangolin.util.camel_case_to_snake_case(name)

pangolin.util.most_specific_class(*args, base_classes=())

pangolin.util.is_numeric_numpy_array(x)

pangolin.util.unzip(source, strict=False)

Reverses zip

Parameters:

source (Sequence[tuple])
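
As a rough illustration of what reversing zip means, the sketch below transposes a sequence of tuples. The name unzip_sketch is hypothetical, and both the return type (a tuple of tuples) and the meaning of strict (reject tuples of unequal length) are assumptions, since the documentation above does not specify them.

```python
def unzip_sketch(source, strict=False):
    """Hypothetical stand-in for unzip (not Pangolin's code)."""
    source = list(source)
    # Assumption: strict=True means every tuple must have the same length.
    if strict and len({len(item) for item in source}) > 1:
        raise ValueError("all tuples must have the same length")
    # zip(*source) transposes rows into columns, undoing an earlier zip.
    return tuple(zip(*source))


pairs = list(zip([1, 2, 3], ["a", "b", "c"]))
print(unzip_sketch(pairs))  # ((1, 2, 3), ('a', 'b', 'c'))
```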

pangolin.util.tree_allclose(a, b, **kwargs)

Checks if two PyTrees are structurally identical and all leaves are close.

This function first verifies that the two PyTrees have the exact same structure. If they do, it then compares each corresponding leaf node (array) using np.allclose and returns True if all leaves are close, False otherwise.

Parameters:
  • a (PyTree) – The first PyTree to compare.

  • b (PyTree) – The second PyTree to compare.

  • kwargs (Any) – Additional keyword arguments passed directly to np.allclose. Common arguments include rtol (relative tolerance) and atol (absolute tolerance).

Returns:

True if the structures match and all corresponding leaves are close. False if the structures match but any leaves are not close.

Raises:

ValueError – If the PyTree structures of a and b are not identical.

Return type:

bool

Examples

>>> tree1 = {'a': jnp.array([1.0, 2.0]), 'b': (jnp.array(3.0),)}
>>> tree2 = {'a': jnp.array([1.000001, 2.0]), 'b': (jnp.array(3.0),)}
>>> tree_allclose(tree1, tree2, atol=1e-5)
True
>>> tree3 = {'a': jnp.array([1.0, 2.5]), 'b': (jnp.array(3.0),)}
>>> tree_allclose(tree1, tree3)
False
>>> tree4 = {'a': jnp.array([1.0, 2.0])} # Different structure
>>> try:
...     tree_allclose(tree1, tree4)
... except ValueError as e:
...     print(e)
PyTree structures do not match.
 a structure: PyTreeDef({'a': *, 'b': (*,)})
 b structure: PyTreeDef({'a': *})

pangolin.util.get_positional_args(target_func, *args, **kwargs)

Transforms a mix of positional and keyword arguments into a list of purely positional arguments for a target function.

This function inspects the signature of target_func to correctly map and order the provided arguments. It ensures that the resulting list of positional arguments, when passed to target_func, will produce the identical outcome as the original call with mixed arguments.

Raises a ValueError if target_func contains any keyword-only arguments, as these cannot be represented in a purely positional call. Raises a TypeError if the provided arguments (*args, **kwargs) are invalid for target_func (e.g., missing required arguments, unexpected arguments).

Parameters:
  • target_func (Callable) – The function whose signature will be used to transform and order the arguments.

  • args (Any) – Positional arguments to be transformed.

  • kwargs (Any) – Keyword arguments to be transformed.

Returns:

A list of arguments in the correct positional order for target_func.

Raises:
  • ValueError – If target_func has keyword-only arguments.

  • TypeError – If the provided *args and **kwargs do not match the signature of target_func.

Examples

>>> def add_five_numbers(a, b, c, d, e):
...     return a + b + c + d + e
>>> # Case 1: All arguments provided positionally
>>> get_positional_args(add_five_numbers, 1, 2, 3, 4, 5)
[1, 2, 3, 4, 5]
>>> # Case 2: Mixed positional and keyword arguments
>>> get_positional_args(add_five_numbers, 1, 2, 3, e=5, d=4)
[1, 2, 3, 4, 5]
>>> # Case 3: All arguments provided as keywords
>>> get_positional_args(add_five_numbers, a=10, b=20, c=30, d=40, e=50)
[10, 20, 30, 40, 50]
>>> # Case 4: Function with default arguments
>>> def greet(name, greeting="Hello"):
...     return f"{greeting}, {name}!"
>>> get_positional_args(greet, "Alice")
['Alice', 'Hello']
>>> get_positional_args(greet, "Bob", greeting="Hi")
['Bob', 'Hi']
>>> get_positional_args(greet, name="Charlie", greeting="Greetings")
['Charlie', 'Greetings']
>>> # Case 5: Verification that the transformed arguments yield the same result
>>> def multiply(x, y, z):
...     return x * y * z
>>> original_args = (2,)
>>> original_kwargs = {'z': 5, 'y': 3}
>>> transformed = get_positional_args(multiply, *original_args, **original_kwargs)
>>> multiply(*original_args, **original_kwargs) == multiply(*transformed)
True
>>> # Case 6: Attempting to transform a function with keyword-only arguments (raises ValueError)
>>> def func_with_kw_only(a, b, *, kw_only_arg, default_kw_only=10):
...     pass
>>> try:
...     get_positional_args(func_with_kw_only, 1, 2, kw_only_arg=3)
... except ValueError as e:
...     print(e)
Function 'func_with_kw_only' has keyword-only arguments (kw_only_arg), which is not allowed by this transformer.
>>> # Case 7: Missing a required argument (raises TypeError)
>>> try:
...     get_positional_args(add_five_numbers, 1, 2, 3)
... except TypeError as e:
...     # Use '...' to match any varying parts of the error message
...     # and check for a key phrase.
...     assert "missing a required argument" in str(e)
...     print(e.__class__.__name__ + ": " + str(e))
TypeError: ...
>>> # Case 8: Providing an unexpected argument (raises TypeError)
>>> try:
...     get_positional_args(add_five_numbers, 1, 2, 3, 4, 5, extra_arg=6)
... except TypeError as e:
...     print(e.__class__.__name__ + ": " + str(e))
TypeError: ... an unexpected keyword argument 'extra_arg'
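
The signature inspection described for get_positional_args can be approximated with the standard inspect module. The following is a sketch under stated assumptions, not the library's implementation: the name get_positional_args_sketch and its error message are hypothetical, and it does not attempt to handle target functions that accept **kwargs.

```python
import inspect


def get_positional_args_sketch(target_func, *args, **kwargs):
    """Hypothetical stand-in for get_positional_args (not Pangolin's code)."""
    signature = inspect.signature(target_func)
    # Keyword-only parameters cannot be expressed positionally.
    kw_only = [
        name
        for name, param in signature.parameters.items()
        if param.kind is inspect.Parameter.KEYWORD_ONLY
    ]
    if kw_only:
        raise ValueError(
            f"{target_func.__name__} has keyword-only arguments: {kw_only}"
        )
    # bind() raises TypeError for missing or unexpected arguments.
    bound = signature.bind(*args, **kwargs)
    # Fill in any defaults that the caller did not supply.
    bound.apply_defaults()
    # bound.args lists the values in positional order.
    return list(bound.args)


def greet(name, greeting="Hello"):
    return f"{greeting}, {name}!"


print(get_positional_args_sketch(greet, "Alice"))            # ['Alice', 'Hello']
print(get_positional_args_sketch(greet, greeting="Hi", name="Bob"))  # ['Bob', 'Hi']
```

Relying on Signature.bind keeps the argument validation identical to what Python itself would do when calling target_func, which is why missing or unexpected arguments surface as TypeError.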