Skip to content

Commit

Permalink
jax.tree_util.* -> jax.tree.*
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Aug 28, 2024
1 parent 36d55d7 commit 8b5e1d9
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 24 deletions.
2 changes: 1 addition & 1 deletion docs/binned_likelihood.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def NLL(dynamic_params, static_params, hists, observation):
# second product of Eq. 1 (constraint)
constraints = evm.loss.get_log_probs(model)
# for parameters with `.value.size > 1` (jnp.sum the constraints)
constraints = jtu.tree_map(jnp.sum, constraints)
constraints = jax.tree.map(jnp.sum, constraints)
loss_val += evm.util.sum_over_leaves(constraints)
return -jnp.sum(loss_val)
```
Expand Down
5 changes: 2 additions & 3 deletions docs/building_blocks.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ Correlate a Parameter
A more general case of correlating any PyTree of parameters is implemented as follows:
```{code-block} python
from typing import NamedTuple
import jax.tree_util as jtu


class Params(NamedTuple):
Expand All @@ -110,11 +109,11 @@ Correlate a Parameter

params = Params(mu=evm.Parameter(1.0), syst=evm.NormalParameter(0.0))

flat_params, tree_def = jtu.tree_flatten(params, evm.parameter.is_parameter)
flat_params, tree_def = jax.tree.flatten(params, evm.parameter.is_parameter)

# correlate the parameters
correlated_flat_params = evm.parameter.correlate(*flat_params)
correlated_params = jtu.tree_unflatten(tree_def, correlated_flat_params)
correlated_params = jax.tree.unflatten(tree_def, correlated_flat_params)

# now correlated_params.mu and correlated_params.syst are correlated,
# they share the same value
Expand Down
2 changes: 1 addition & 1 deletion examples/dnn_weights_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def loss_fn(model, x, y):
mse = jax.numpy.mean((y - pred_y) ** 2)
constraints = evm.loss.get_log_probs(model)
# sum them all up for each weight
constraints = jax.tree_util.tree_map(jnp.sum, constraints)
constraints = jax.tree.map(jnp.sum, constraints)
return mse + evm.util.sum_over_leaves(constraints)


Expand Down
10 changes: 4 additions & 6 deletions src/evermore/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array, ArrayLike, PyTree

from evermore.custom_types import ModifierLike, OffsetAndScale
Expand Down Expand Up @@ -73,7 +72,6 @@ class ModifierBase(ApplyFn, MatMulCompose, AbstractModifier, SupportsTreescope):
import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array
import evermore as evm
Expand All @@ -86,7 +84,7 @@ class Clip(evm.modifier.ModifierBase):
def offset_and_scale(self, hist: Array) -> evm.custom_types.OffsetAndScale:
os = self.modifier.offset_and_scale(hist)
return jtu.tree_map(lambda x: jnp.clip(x, self.min_sf, self.max_sf), os)
return jax.tree.map(lambda x: jnp.clip(x, self.min_sf, self.max_sf), os)
parameter = evm.Parameter(value=1.1)
Expand Down Expand Up @@ -201,7 +199,7 @@ def offset_and_scale(self, hist: Array) -> OffsetAndScale:
def _where(true: Array, false: Array) -> Array:
return jnp.where(self.condition, true, false)

return jtu.tree_map(_where, true_os, false_os)
return jax.tree.map(_where, true_os, false_os)


class BooleanMask(ModifierBase):
Expand Down Expand Up @@ -280,7 +278,7 @@ class Transform(ModifierBase):

def offset_and_scale(self, hist: Array) -> OffsetAndScale:
os = self.modifier.offset_and_scale(hist)
return jtu.tree_map(self.transform_fn, os)
return jax.tree.map(self.transform_fn, os)


class TransformOffset(ModifierBase):
Expand Down Expand Up @@ -377,7 +375,7 @@ def offset_and_scale(self, hist: Array) -> OffsetAndScale:
groups = defaultdict(list)
# first group modifiers into same tree structures
for mod in self.unroll_modifiers():
groups[hash(jtu.tree_structure(mod))].append(mod)
groups[hash(jax.tree.structure(mod))].append(mod)
# then do the `jax.lax.scan` loops
for _, group_mods in groups.items():
# skip empty groups
Expand Down
16 changes: 7 additions & 9 deletions src/evermore/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array, ArrayLike, PRNGKeyArray, PyTree

from evermore.custom_types import PDFLike
Expand Down Expand Up @@ -139,15 +138,15 @@ def model(diffable, static, hists) -> Array:
...
"""
# 1. set the filter_spec to False for all non-static leaves
filter_spec = jtu.tree_map(lambda _: False, tree)
filter_spec = jax.tree.map(lambda _: False, tree)

# 2. set the filter_spec to True for each parameter value
def _replace_value(leaf: Any) -> Any:
if isinstance(leaf, Parameter):
leaf = eqx.tree_at(lambda p: p.value, leaf, not leaf.frozen)
return leaf

return jtu.tree_map(_replace_value, filter_spec, is_leaf=is_parameter)
return jax.tree.map(_replace_value, filter_spec, is_leaf=is_parameter)


def partition(tree: PyTree) -> tuple[PyTree, PyTree]:
Expand Down Expand Up @@ -175,11 +174,11 @@ def sample(tree: PyTree, key: PRNGKeyArray) -> PyTree:
"""
# partition the tree into parameters and the rest
params_tree, rest_tree = eqx.partition(tree, is_parameter, is_leaf=is_parameter)
params_structure = jax.tree_util.tree_structure(params_tree)
params_structure = jax.tree.structure(params_tree)
n_params = params_structure.num_leaves # type: ignore[attr-defined]

keys = jax.random.split(key, n_params)
keys_tree = jax.tree_util.tree_unflatten(params_structure, keys)
keys_tree = jax.tree.unflatten(params_structure, keys)

def _sample(param: Parameter, key: Parameter) -> Array:
if isinstance(param.prior, PDFLike):
Expand Down Expand Up @@ -209,7 +208,7 @@ def _sample(param: Parameter, key: Parameter) -> Array:
return eqx.tree_at(lambda p: p.value, param, sampled_value)

# sample for each parameter
sampled_params_tree = jtu.tree_map(
sampled_params_tree = jax.tree.map(
_sample, params_tree, keys_tree, is_leaf=is_parameter
)

Expand Down Expand Up @@ -246,7 +245,6 @@ def model(*parameters: PyTree[evm.Parameter]):
# More general case of correlating any PyTree of parameters
from typing import NamedTuple
import jax.tree_util as jtu
class Params(NamedTuple):
Expand All @@ -256,11 +254,11 @@ class Params(NamedTuple):
params = Params(mu=evm.Parameter(1.0), syst=evm.NormalParameter(0.0))
def model(params: Params):
flat_params, tree_def = jtu.tree_flatten(params, evm.parameter.is_parameter)
flat_params, tree_def = jax.tree.flatten(params, evm.parameter.is_parameter)
# correlate the parameters
correlated_flat_params = evm.parameter.correlate(*flat_params)
correlated_params = jtu.tree_unflatten(tree_def, correlated_flat_params)
correlated_params = jax.tree.unflatten(tree_def, correlated_flat_params)
# now correlated_params.mu and correlated_params.syst are correlated, i.e., they share the same value
assert correlated_params.mu.value == correlated_params.syst.value
Expand Down
8 changes: 4 additions & 4 deletions src/evermore/staterror.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import cast

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array, PyTree

from evermore.custom_types import ModifierLike
Expand Down Expand Up @@ -67,7 +67,7 @@ def __init__(
threshold: float = 10.0,
) -> None:
assert (
jtu.tree_structure(hists) == jtu.tree_structure(histsw2) # type: ignore[operator]
jax.tree.structure(hists) == jax.tree.structure(histsw2) # type: ignore[operator]
), "The PyTree structure of hists and histsw2 must be the same!"
self.hists = hists
self.histsw2 = histsw2
Expand All @@ -80,10 +80,10 @@ def __init__(

# setup params
self.gaussians_global = NormalParameter(value=jnp.zeros_like(self.ntot))
self.gaussians_per_process = jtu.tree_map(
self.gaussians_per_process = jax.tree.map(
lambda hist: NormalParameter(value=jnp.zeros_like(hist)), self.hists
)
self.poissons_per_process = jtu.tree_map(
self.poissons_per_process = jax.tree.map(
lambda w, w2: Parameter(
value=jnp.zeros_like(w),
prior=Poisson(
Expand Down

0 comments on commit 8b5e1d9

Please sign in to comment.