util Module
This submodule contains internal utility functions. None of these functions import or use any other parts of Pangolin. End-users of Pangolin are not typically expected to use these functions directly.
- pangolin.util.comma_separated(stuff, fun=None, parens=True, spaces=False)
Convenience function for turning sequences into comma-separated strings.
- Parameters:
stuff (Sequence) – sequence of items to print
fun (Callable | None) – function to apply to each item before printing
parens (bool) – whether to wrap the result in parentheses
spaces (bool) – whether to put spaces between items
- Returns:
string containing all items with commas between them
- Return type:
str
Examples
>>> comma_separated(['a', 'b', 'c'])
'(a,b,c)'
>>> comma_separated(['a', 'b', 'c'], lambda s: s + "0")
'(a0,b0,c0)'
>>> comma_separated(['a', 'b', 'c'], parens=False)
'a,b,c'
>>> comma_separated(['a', 'b', 'c'], lambda s: s + "0", parens=False)
'a0,b0,c0'
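To make the four parameters concrete, here is a minimal pure-Python sketch of a function with the same signature. It is an illustration, not Pangolin's source; in particular, the output for spaces=True (one space after each comma) is an assumption, since none of the examples above use it.

```python
from typing import Any, Callable, Optional, Sequence


def comma_separated(
    stuff: Sequence[Any],
    fun: Optional[Callable[[Any], Any]] = None,
    parens: bool = True,
    spaces: bool = False,
) -> str:
    # Apply the optional per-item transformation, then convert each item to text.
    items = [str(fun(x) if fun is not None else x) for x in stuff]
    # Assumption: spaces=True means one space after each comma.
    sep = ", " if spaces else ","
    body = sep.join(items)
    return f"({body})" if parens else body
```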
- class pangolin.util.VarNames
Convenience class to automatically give unique string names to objects.
This class has no parameters.
Examples
>>> var_names = VarNames()
>>> var_names['bob']
'v0v'
>>> var_names['alice']
'v1v'
>>> var_names['bob']
'v0v'
>>> var_names['carlos']
'v2v'
>>> var_names['alice']
'v1v'
>>> var_names['bob']
'v0v'
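The naming scheme in the examples (the first new key gets v0v, the next v1v, and a repeated key gets its old name back) can be reproduced with a dictionary that hands out the next index on first lookup. This is a sketch assuming hashable keys; Pangolin's class may key on something else, such as object identity.

```python
from typing import Dict, Hashable


class VarNames:
    """Hand out names v0v, v1v, ... in order of first lookup (illustrative sketch)."""

    def __init__(self) -> None:
        self._names: Dict[Hashable, str] = {}

    def __getitem__(self, key: Hashable) -> str:
        # A new key gets the next unused index; a known key gets its old name back.
        if key not in self._names:
            self._names[key] = f"v{len(self._names)}v"
        return self._names[key]
```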
- pangolin.util.num_not_none(*args)
How many items are not None?
- Parameters:
args (Any) – any number of arguments of any type
- Returns:
number of items that are not None
- Return type:
int
Examples
>>> num_not_none(1)
1
>>> num_not_none(1,2)
2
>>> num_not_none(1,None)
1
>>> num_not_none(None,1,None)
1
>>> num_not_none(None,"cat","dog",77,None)
3
- pangolin.util.all_unique(a)
Are all items in a sequence unique?
- Parameters:
a (Sequence) – some sequence
- Returns:
True if all items are unique, False otherwise.
- Return type:
bool
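all_unique has no examples above. One straightforward way to implement such a check is sketched below; it is an illustration, not Pangolin's code. The fast path needs hashable items, and the fallback assumes items compare with != returning a plain bool (NumPy arrays, for instance, would not).

```python
from typing import Any, Sequence


def all_unique(a: Sequence[Any]) -> bool:
    try:
        # Fast path for hashable items: duplicates collapse inside the set.
        return len(set(a)) == len(a)
    except TypeError:
        # Fallback for unhashable items (e.g. lists): O(n^2) pairwise comparison.
        return all(a[i] != a[j] for i in range(len(a)) for j in range(i + 1, len(a)))
```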
- pangolin.util.intersects(A, B)
Check if two collections have any common elements.
- Parameters:
A (Iterable) – First collection of elements.
B (Iterable) – Second collection of elements.
- Returns:
True if A and B share at least one element, False otherwise.
- Return type:
bool
Examples
>>> intersects([1, 2, 3], [3, 4, 5])
True
>>> intersects(['apple', 'banana'], ['orange', 'grape'])
False
- class pangolin.util.WriteOnceDict
A dict where you can’t overwrite entries. Takes no parameters.
Examples
>>> d = WriteOnceDict()
>>> d['a'] = 2
>>> d['a']
2
>>> d['b'] = 3
>>> d['b']
3
>>> d['a'] = 1
Traceback (most recent call last):
ValueError: Cannot overwrite existing key a in WriteOnceDict
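A write-once dict can be sketched as a dict subclass that refuses item assignment to an existing key; the error message below mirrors the one in the example above. This is an illustrative sketch rather than Pangolin's implementation.

```python
from typing import Any, Hashable


class WriteOnceDict(dict):
    """dict that raises instead of overwriting an existing key (illustrative sketch)."""

    def __setitem__(self, key: Hashable, value: Any) -> None:
        # Refuse to replace a value that has already been written.
        if key in self:
            raise ValueError(f"Cannot overwrite existing key {key} in WriteOnceDict")
        super().__setitem__(key, value)
```

One design caveat: dict.update does not call an overridden __setitem__, so d.update(...) could still overwrite keys in this sketch; a stricter version would override update as well.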
- class pangolin.util.WriteOnceDefaultDict(default_factory)
A dict where you can’t overwrite entries. If you access a key that has no value yet, default_factory is called with that key to create the key/value pair automatically. Once you have written to or accessed a key, it cannot be changed.
- Parameters:
default_factory (Callable) – function giving values for non-existent entries; it is called with the missing key
Examples
>>> d = WriteOnceDefaultDict(lambda s : len(s))
>>> d['cat'] = 5
>>> d['bob']
3
>>> d['cat']
5
>>> d['bob'] = 5  # raises an error (as expected)
Traceback (most recent call last):
ValueError: Cannot overwrite existing key bob in WriteOnceDefaultDict
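The key-dependent default in the example (lambda s: len(s) makes d['bob'] equal 3) means the factory receives the missing key, unlike collections.defaultdict, whose default_factory is called with no arguments. A minimal sketch using Python's dict.__missing__ hook, written for illustration and not taken from Pangolin:

```python
from typing import Any, Callable, Hashable


class WriteOnceDefaultDict(dict):
    """Write-once dict whose missing values are built from the key (illustrative sketch)."""

    def __init__(self, default_factory: Callable[[Any], Any]) -> None:
        super().__init__()
        self.default_factory = default_factory

    def __missing__(self, key: Hashable) -> Any:
        # dict.__getitem__ calls this for absent keys. The new value is stored,
        # so the key counts as written from now on.
        value = self.default_factory(key)
        super().__setitem__(key, value)
        return value

    def __setitem__(self, key: Hashable, value: Any) -> None:
        if key in self:
            raise ValueError(f"Cannot overwrite existing key {key} in WriteOnceDefaultDict")
        super().__setitem__(key, value)
```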
- pangolin.util.tree_map_with_none_as_leaf(f, tree, *rest)
Call jax.tree_util.tree_map using a special is_leaf function that preserves None. This is exactly like jax.tree.map except that None nodes are treated as leaves rather than as nonexistent.
- Parameters:
f (Callable) – function that takes 1 + len(rest) arguments, to be applied to corresponding leaves
tree (PyTree) – a pytree to be mapped over, with each leaf providing the first positional argument to f
rest (PyTree) – additional pytrees to be mapped over, each of which must either have the same structure as tree or have tree as a prefix.
Examples
>>> f = lambda x: 0 if x is None else len(x)
>>> tree = ["cat", ("bear", None)]
>>> jax.tree_util.tree_map(f, tree)
[3, (4, None)]
>>> tree_map_with_none_as_leaf(f, tree)
[3, (4, 0)]
>>> f2 = lambda leaf, seq: f(leaf) + len(seq)
>>> tree1 = ["cat", ("bear", None)]
>>> tree2 = [[1,2,3,4,5], ([], [1,2,3,4,5])]
>>> tree_map_with_none_as_leaf(f2, tree1, tree2)
[8, (4, 5)]
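The difference between the two behaviours can be reproduced without JAX. The toy_tree_map below is a hypothetical stand-in for jax.tree_util.tree_map that handles only lists, tuples and dicts. It either skips None as an empty subtree (JAX's default) or passes None to f as a leaf (what tree_map_with_none_as_leaf does). Leaves of tree are matched with whole subtrees of rest, which is the prefix rule at work in the second example.

```python
from typing import Any, Callable


def toy_tree_map(f: Callable[..., Any], tree: Any, *rest: Any, none_is_leaf: bool = False) -> Any:
    """Hypothetical stand-in for jax.tree_util.tree_map on lists, tuples and dicts."""
    if tree is None and not none_is_leaf:
        # JAX-style default: None is an empty subtree, so f is never called on it.
        # (Real JAX additionally checks the structure of rest here; this toy does not.)
        return None
    if isinstance(tree, dict):
        return {k: toy_tree_map(f, tree[k], *(r[k] for r in rest), none_is_leaf=none_is_leaf)
                for k in tree}
    if isinstance(tree, (list, tuple)):
        children = [toy_tree_map(f, t, *(r[i] for r in rest), none_is_leaf=none_is_leaf)
                    for i, t in enumerate(tree)]
        return type(tree)(children)
    # A leaf of tree: the matching nodes of rest are passed whole, even if they
    # are subtrees (tree acts as a prefix of rest).
    return f(tree, *rest)
```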
- pangolin.util.tree_map_preserve_none(f, tree, *rest)
- Parameters:
f (Callable)
tree (PyTree)
rest (PyTree)
- pangolin.util.same(x, y)
Assert that x and y are either:
* both None,
* equal floats, or
* equal arrays (with equal shapes).
- pangolin.util.same_tree(x, y, is_leaf=None)
Check that x and y have the same tree structure (including None) and that all leaves are equal. Arrays are compared by value regardless of whether they come from regular NumPy or JAX NumPy, and their dtypes are ignored (e.g. numpy.array([2,3]) is considered equal to jax.numpy.array([2.0,3.0])).
- Parameters:
x (PyTree)
y (PyTree)
is_leaf (Callable | None)
- pangolin.util.map_inside_tree(f, tree)
Map a function over the leading axis of all leaf nodes inside a tree. If a None value is encountered in the input, it is passed unchanged to each call of f. If a None appears in the output, it is returned unchanged.
Examples
>>> def f(t):
...     a, (b, c) = t
...     return (a + b, a * b), c
>>> tree = np.array([1, 2]), (np.array([3, 4]), np.array([5, 6]))
>>> map_inside_tree(f, tree)
((array([4, 6]), array([3, 8])), array([5, 6]))
>>> def f(t):
...     a, (b, c) = t
...     return (a + b, c), None
>>> tree = np.array([1, 2]), (np.array([3, 4]), None)
>>> map_inside_tree(f, tree)
((array([4, 6]), None), None)
- Parameters:
f (Callable)
tree (PyTree)
- pangolin.util.assert_all_leaves_instance_of(tree, type, is_leaf=None)
- Parameters:
tree (PyTree)
type (type)
is_leaf (Callable | None)
- pangolin.util.assert_all_leaves_instance_of_with_none(tree, type)
- Parameters:
tree (PyTree)
type (type)
- pangolin.util.tree_map_recurse_at_leaf(f, tree, *remaining_trees, is_leaf=None)
Applies a function f to corresponding leaves of tree and *remaining_trees.
This function implements a “recursive broadcast” behavior. If tree has a leaf at a path where any of remaining_trees has a subtree, that tree leaf is “broadcast” to all leaves of the corresponding subtree in remaining_trees.
- Parameters:
f (Callable[[...], Any]) – The function to apply to the leaves. Its first argument will be a leaf from tree, and subsequent arguments will be corresponding leaves from remaining_trees. If broadcasting occurs, the first argument (leaf_from_tree) will be fixed for all leaves within the broadcasted subtree.
tree (PyTree) – The primary PyTree. Its leaves will trigger the broadcasting behavior. It is expected to be a “prefix” or “smaller” structure compared to remaining_trees at corresponding paths.
remaining_trees (PyTree) – One or more additional PyTrees to map over. These are expected to be “superset” structures relative to tree at corresponding paths.
is_leaf (Callable[[Any], bool] | None) – An optional callable that takes a single argument (a node in a PyTree) and returns True if that node should be considered a leaf (i.e., tree_map should not recurse into it), and False otherwise. This function is applied to both the outer and inner jax.tree_map calls. If None, JAX’s default leaf detection is used.
- Returns:
PyTree – A new PyTree with the results of applying f. Its structure will match that of the first PyTree in remaining_trees (or tree if remaining_trees is empty).
- Return type:
PyTree
Notes
If tree has a subtree at a path where one of remaining_trees has a leaf, jax.tree_map will raise a ValueError due to a structural mismatch.
This function leverages nested jax.tree_map calls for conciseness.
Examples
>>> # No broadcasting (standard tree_map behavior)
>>> tree1 = {'c': 5, 'd': 6}
>>> tree2 = {'c': 2, 'd': 3}
>>> tree_map_recurse_at_leaf(lambda l1, l2: l1 * l2, tree1, tree2)
{'c': 10, 'd': 18}
>>> # Simple broadcasting (multiply first leaf by second)
>>> tree1 = {'a': 10, 'b': 20}
>>> tree2 = {'a': {'x': 1, 'y': 2}, 'b': 3}
>>> tree_map_recurse_at_leaf(lambda l1, l2: l1 * l2, tree1, tree2)
{'a': {'x': 10, 'y': 20}, 'b': 60}
>>> # Custom is_leaf (treating None as a leaf)
>>> tree1 = {'data': 100, 'config': None}
>>> tree2 = {'data': {'val': 1, 'factor': 2}, 'config': 'default'}
>>> tree_map_recurse_at_leaf(
...     lambda l1, l2: f"{l1}_{l2}" if l1 is None else l1 * l2,
...     tree1, tree2, is_leaf=lambda x: x is None
... )
{'config': 'None_default', 'data': {'factor': 200, 'val': 100}}
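The two-level walk can be seen concretely in a pure-Python sketch (lists, tuples and dicts only, no JAX) of the same recursive-broadcast rule: an outer walk descends while tree still has structure, and an inner walk keeps the tree leaf fixed while descending the remaining trees. This illustrates the documented behaviour; it is not Pangolin's nested jax.tree_map implementation, and one simplification is that a bare None is treated as an ordinary leaf.

```python
from typing import Any, Callable, Optional


def tree_map_recurse_at_leaf(
    f: Callable[..., Any],
    tree: Any,
    *remaining_trees: Any,
    is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Any:
    def is_node(x: Any) -> bool:
        # Containers are internal nodes unless is_leaf says otherwise.
        # Simplification vs. JAX: a bare None is just a leaf here.
        if is_leaf is not None and is_leaf(x):
            return False
        return isinstance(x, (list, tuple, dict))

    def map_node(node: Any, fn: Callable[[Any], Any]) -> Any:
        # Rebuild a list/tuple/dict node after applying fn to each child key.
        if isinstance(node, dict):
            return {k: fn(k) for k in node}
        return type(node)(fn(i) for i in range(len(node)))

    def broadcast(leaf: Any, *subtrees: Any) -> Any:
        # Inner walk: leaf stays fixed while the other trees are descended.
        if is_node(subtrees[0]):
            return map_node(subtrees[0], lambda k: broadcast(leaf, *(s[k] for s in subtrees)))
        return f(leaf, *subtrees)

    def outer(t: Any, *rest: Any) -> Any:
        # Outer walk: descend while tree still has structure.
        if is_node(t):
            return map_node(t, lambda k: outer(t[k], *(r[k] for r in rest)))
        return broadcast(t, *rest) if rest else f(t)

    return outer(tree, *remaining_trees)
```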
- pangolin.util.tree_map_recurse_at_leaf_with_none_as_leaf(f, tree, *remaining_trees)
Examples
>>> pytree1 = (0,[1,2])
>>> pytree2 = ("dog", ["cat", None])
>>> tree_map_recurse_at_leaf_with_none_as_leaf(lambda a,b: a, pytree1, pytree2)
(0, [1, 2])
>>> pytree1 = 3
>>> pytree2 = {"cat": 0, "dog": 2}
>>> tree_map_recurse_at_leaf_with_none_as_leaf(lambda a,b: a, pytree1, pytree2)
{'cat': 3, 'dog': 3}
- Parameters:
f (Callable)
tree (PyTree)
remaining_trees (PyTree)
- pangolin.util.flatten_fun(f, *args, is_leaf=None)
Get a new function that takes a single input (which is a list).
- Parameters:
f (Callable)
args (Any)
is_leaf (Callable[[Any], bool] | None)
- pangolin.util.dual_flatten(pytree1, pytree2)
- Parameters:
pytree1 (PyTree) – first pytree
pytree2 (PyTree) – second pytree, must have the same structure as pytree1 or at least have the structure of pytree1 as a prefix
- Return type:
PyTree
Examples
>>> pytree1 = (0,[1,2])
>>> pytree2 = ("dog",["cat", "owl"])
>>> dual_flatten(pytree1, pytree2)
([0, 1, 2], ['dog', 'cat', 'owl'])
>>> pytree1 = (None,[1,2])
>>> pytree2 = ("dog", ["cat", None])
>>> dual_flatten(pytree1, pytree2)
([None, 1, 2], ['dog', 'cat', None])
>>> pytree1 = (None, 3)
>>> pytree2 = ("dog", ["cat", None])
>>> dual_flatten(pytree1, pytree2)
([None, 3, 3], ['dog', 'cat', None])
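Judging from the examples, dual_flatten walks both trees together, treats None as a leaf, and repeats a leaf of pytree1 once for every leaf in the matching subtree of pytree2 (the 3 in the last example). A pure-Python sketch of that behaviour, inferred from the examples rather than taken from Pangolin and restricted to lists, tuples and dicts:

```python
from typing import Any, List, Tuple


def dual_flatten(pytree1: Any, pytree2: Any) -> Tuple[List[Any], List[Any]]:
    def is_node(x: Any) -> bool:
        # None is not a node here, so it survives as a leaf.
        return isinstance(x, (list, tuple, dict))

    def keys(x: Any):
        # Visit dict keys in sorted order, as JAX does when flattening dicts.
        return sorted(x) if isinstance(x, dict) else range(len(x))

    def flat(x: Any) -> List[Any]:
        if is_node(x):
            return [leaf for k in keys(x) for leaf in flat(x[k])]
        return [x]

    leaves1: List[Any] = []
    leaves2: List[Any] = []

    def walk(t1: Any, t2: Any) -> None:
        if is_node(t1):
            for k in keys(t1):
                walk(t1[k], t2[k])
        else:
            # Leaf of pytree1: repeat it once per leaf of the matching subtree of pytree2.
            sub = flat(t2)
            leaves1.extend([t1] * len(sub))
            leaves2.extend(sub)

    walk(pytree1, pytree2)
    return leaves1, leaves2
```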
- pangolin.util.short_pytree_string(treedef)
Get a string for a JAX PyTreeDef without printing PyTreeDef and scaring the noobs
- pangolin.util.assimilate_vals(vars, vals)
Convert vals to a pytree of arrays with the same shape as vars.
The purpose of this is to handle cases where users provide lists / tuples that should be auto-cast to a pytree of arrays. (Without vars it would be impossible to tell a list of arrays of the same length from a big array with one more dimension.)
- Parameters:
vars (PyTree)
vals (PyTree)
- Return type:
PyTree[jnp.ndarray]
- pangolin.util.first(lst, cond, default=None)
Get the first element of lst satisfying cond, or default if there is none.
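The usual Python idiom for this behaviour is next() over a generator with a fallback value; a sketch (not necessarily Pangolin's code):

```python
from typing import Any, Callable, Iterable, Optional


def first(lst: Iterable[Any], cond: Callable[[Any], bool], default: Optional[Any] = None) -> Any:
    # next() stops at the first match, so later elements are never tested;
    # if nothing matches, the generator is exhausted and default is returned.
    return next((x for x in lst if cond(x)), default)
```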
- pangolin.util.unzip(source, strict=False)
Reverses zip
- Parameters:
source (Sequence[tuple])
- pangolin.util.tree_allclose(a, b, **kwargs)
Checks if two PyTrees are structurally identical and all leaves are close.
This function first verifies that the two PyTrees have the exact same structure. If they do, it then compares each corresponding leaf node (array) using np.allclose and returns True if all leaves are close, False otherwise.
- Parameters:
a (PyTree) – The first PyTree to compare.
b (PyTree) – The second PyTree to compare.
kwargs (Any) – Additional keyword arguments to be passed directly to np.allclose. Common arguments include rtol (relative tolerance) and atol (absolute tolerance).
- Returns:
True if the structures match and all corresponding leaves are close. False if the structures match but any leaves are not close.
- Raises:
ValueError – If the PyTree structures of a and b are not identical.
- Return type:
bool
Examples
>>> tree1 = {'a': jnp.array([1.0, 2.0]), 'b': (jnp.array(3.0),)}
>>> tree2 = {'a': jnp.array([1.000001, 2.0]), 'b': (jnp.array(3.0),)}
>>> tree_allclose(tree1, tree2, atol=1e-5)
True
>>> tree3 = {'a': jnp.array([1.0, 2.5]), 'b': (jnp.array(3.0),)}
>>> tree_allclose(tree1, tree3)
False
>>> tree4 = {'a': jnp.array([1.0, 2.0])}  # Different structure
>>> try:
...     tree_allclose(tree1, tree4)
... except ValueError as e:
...     print(e)
PyTree structures do not match. a structure: PyTreeDef({'a': *, 'b': (*,)}) b structure: PyTreeDef({'a': *})
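The comparison rule can be illustrated without NumPy or JAX. The sketch below treats nested lists, tuples and dicts as structure and plain numbers as leaves (a simplification: the real function compares array leaves and forwards **kwargs to np.allclose). It checks structure first, then applies the same elementwise test as np.allclose, |x - y| <= atol + rtol * |y|, with NumPy's default tolerances.

```python
from typing import Any, List


def tree_allclose(a: Any, b: Any, rtol: float = 1e-05, atol: float = 1e-08) -> bool:
    def structure(x: Any) -> Any:
        # Shape of the tree with leaves erased; dict keys sorted as in JAX.
        if isinstance(x, dict):
            return ("dict", tuple((k, structure(x[k])) for k in sorted(x)))
        if isinstance(x, (list, tuple)):
            return (type(x).__name__, tuple(structure(c) for c in x))
        return "*"

    def leaves(x: Any) -> List[Any]:
        if isinstance(x, dict):
            return [leaf for k in sorted(x) for leaf in leaves(x[k])]
        if isinstance(x, (list, tuple)):
            return [leaf for c in x for leaf in leaves(c)]
        return [x]

    if structure(a) != structure(b):
        raise ValueError("PyTree structures do not match.")
    # Same per-element rule as np.allclose (note: not symmetric in a and b).
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(leaves(a), leaves(b)))
```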
- pangolin.util.get_positional_args(target_func, *args, **kwargs)
Transforms a mix of positional and keyword arguments into a list of purely positional arguments for a target function.
This function inspects the signature of target_func to correctly map and order the provided arguments. It ensures that the resulting list of positional arguments, when passed to target_func, will produce the same outcome as the original call with mixed arguments.
Raises a ValueError if target_func has any keyword-only arguments, as these cannot be represented in a purely positional call. Raises a TypeError if the provided arguments (*args, **kwargs) are invalid for target_func (e.g., missing required arguments, unexpected arguments).
- Parameters:
target_func (Callable) – The function whose signature will be used to transform and order the arguments.
args (Any) – Positional arguments to be transformed.
kwargs (Any) – Keyword arguments to be transformed.
- Returns:
A list of arguments in the correct positional order for target_func.
- Raises:
ValueError – If target_func has keyword-only arguments.
TypeError – If the provided *args and **kwargs do not match the signature of target_func.
Examples
>>> def add_five_numbers(a, b, c, d, e):
...     return a + b + c + d + e
>>> # Case 1: All arguments provided positionally
>>> get_positional_args(add_five_numbers, 1, 2, 3, 4, 5)
[1, 2, 3, 4, 5]
>>> # Case 2: Mixed positional and keyword arguments
>>> get_positional_args(add_five_numbers, 1, 2, 3, e=5, d=4)
[1, 2, 3, 4, 5]
>>> # Case 3: All arguments provided as keywords
>>> get_positional_args(add_five_numbers, a=10, b=20, c=30, d=40, e=50)
[10, 20, 30, 40, 50]
>>> # Case 4: Function with default arguments
>>> def greet(name, greeting="Hello"):
...     return f"{greeting}, {name}!"
>>> get_positional_args(greet, "Alice")
['Alice', 'Hello']
>>> get_positional_args(greet, "Bob", greeting="Hi")
['Bob', 'Hi']
>>> get_positional_args(greet, name="Charlie", greeting="Greetings")
['Charlie', 'Greetings']
>>> # Case 5: Verification that the transformed arguments yield the same result
>>> def multiply(x, y, z):
...     return x * y * z
>>> original_args = (2,)
>>> original_kwargs = {'z': 5, 'y': 3}
>>> transformed = get_positional_args(multiply, *original_args, **original_kwargs)
>>> multiply(*original_args, **original_kwargs) == multiply(*transformed)
True
>>> # Case 6: Attempting to transform a function with keyword-only arguments (raises ValueError)
>>> def func_with_kw_only(a, b, *, kw_only_arg, default_kw_only=10):
...     pass
>>> try:
...     get_positional_args(func_with_kw_only, 1, 2, kw_only_arg=3)
... except ValueError as e:
...     print(e)
Function 'func_with_kw_only' has keyword-only arguments (kw_only_arg), which is not allowed by this transformer.
>>> # Case 7: Missing a required argument (raises TypeError)
>>> try:
...     get_positional_args(add_five_numbers, 1, 2, 3)
... except TypeError as e:
...     # Use '...' to match any varying parts of the error message
...     # and check for a key phrase.
...     assert "missing a required argument" in str(e)
...     print(e.__class__.__name__ + ": " + str(e))
TypeError: ...
>>> # Case 8: Providing an unexpected argument (raises TypeError)
>>> try:
...     get_positional_args(add_five_numbers, 1, 2, 3, 4, 5, extra_arg=6)
... except TypeError as e:
...     print(e.__class__.__name__ + ": " + str(e))
TypeError: ... an unexpected keyword argument 'extra_arg'
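The behaviour described above maps closely onto the standard library's inspect.Signature.bind, which raises TypeError for missing or unexpected arguments just as a real call would, followed by BoundArguments.apply_defaults. A sketch along those lines, not necessarily how Pangolin implements it (the ValueError wording here is made up):

```python
import inspect
from typing import Any, Callable, List


def get_positional_args(target_func: Callable, *args: Any, **kwargs: Any) -> List[Any]:
    sig = inspect.signature(target_func)
    # Keyword-only parameters cannot be expressed in a purely positional call.
    kw_only = [p.name for p in sig.parameters.values()
               if p.kind is inspect.Parameter.KEYWORD_ONLY]
    if kw_only:
        raise ValueError(
            f"Function '{target_func.__name__}' has keyword-only arguments "
            f"({', '.join(kw_only)}), which cannot be passed positionally."
        )
    # bind() raises TypeError for missing or unexpected arguments.
    bound = sig.bind(*args, **kwargs)
    # Fill in defaults so omitted optional arguments appear in the result.
    bound.apply_defaults()
    return list(bound.args)
```

This sketch does not special-case functions with a **kwargs parameter; values collected there are not part of bound.args, so a fuller version would reject such functions too.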