Skip to content

Commit

Permalink
update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed May 2, 2024
1 parent 18f58e6 commit b36faaf
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 56 deletions.
37 changes: 20 additions & 17 deletions docs/building_blocks.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
---
jupytext:
formats: md:myst
text_representation:
extension: .md
format_name: myst
kernelspec:
display_name: Python 3
language: python
name: python3
---

(building-blocks)=
# Building Blocks

Expand Down Expand Up @@ -117,26 +129,12 @@ Inspect a (PyTree of) `evm.Parameters` with [`penzai`'s treescope](https://penza
You can even add custom visualizers, such as:

```{code-block} python
from penzai import pz
import evermore as evm
import plotly.express as px
tree = {"a": evm.NormalParameter(), "b": evm.NormalParameter()}
# custom plotly visualization of prior PDF
def plot_prior_pdf(value, path):
if isinstance(value, evm.Parameter) and isinstance(value.prior, evm.custom_types.PDFLike):
x = jnp.arange(-3, 3, 0.1)
return pz.ts.IPythonVisualization(
px.bar(
x=x, y=jnp.exp(value.prior.log_prob(x)),
width=400, height=200
).update_layout(margin=dict(l=20, r=20, t=20, b=20))
)
with pz.ts.active_autovisualizer.set_scoped(plot_prior_pdf):
pz.ts.display(tree)
evm.visualization.display(tree)
```
:::

Expand Down Expand Up @@ -271,11 +269,14 @@ Modifier that scales a histogram based on vertical template morphing (Normal con

Multiple modifiers should be combined using `evm.modifier.Compose` or the `@` operator:

```{code-block} python
```{code-cell} ipython3
import jax
import jax.numpy as jnp
import evermore as evm
jax.config.update("jax_enable_x64", True)
param = evm.NormalParameter(value=0.1)
modifier1 = param.morphing(
Expand All @@ -286,6 +287,8 @@ modifier1 = param.morphing(
modifier2 = param.scale_log(up=1.1, down=0.9)
# apply the composed modifier
(modifier1 @ modifier2)(jnp.array([10, 20, 30])
(modifier1 @ modifier2)(jnp.array([10, 20, 30]))
# -> Array([10.259877, 20.500944, 30.760822], dtype=float32)
evm.visualization.display(modifier1 @ modifier2)
```
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
html_favicon = "../assets/favicon.png"

extensions = [
"myst_parser",
"myst_nb",
# "myst_parser",
"sphinx.ext.autodoc",
"sphinx.ext.intersphinx",
"sphinx.ext.viewcode",
Expand Down
71 changes: 53 additions & 18 deletions docs/tips_and_tricks.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,59 @@
---
jupytext:
formats: md:myst
text_representation:
extension: .md
format_name: myst
kernelspec:
display_name: Python 3
language: python
name: python3
---


# Tips and Tricks

Here are some advanced tips and tricks.


## penzai Visualization

Use `penzai` to visualize evermore components!

```{code-cell} ipython3
import jax
import jax.numpy as jnp
import evermore as evm
import equinox as eqx
jax.config.update("jax_enable_x64", True)
mu = evm.Parameter(value=1.1)
sigma1 = evm.NormalParameter(value=0.1)
sigma2 = evm.NormalParameter(value=0.2)
hist = jnp.array([10, 20, 30])
mu_mod = mu.scale(offset=0, slope=1)
sigma1_mod = sigma1.scale_log(up=1.1, down=0.9)
sigma2_mod = sigma2.scale_log(up=1.05, down=0.95)
composition = evm.modifier.Compose(
mu_mod,
sigma1_mod,
evm.modifier.Where(hist < 15, sigma2_mod, sigma1_mod),
)
composition = evm.modifier.Compose(
composition,
evm.Modifier(parameter=sigma1, effect=evm.effect.AsymmetricExponential(up=1.2, down=0.8)),
)
evm.visualization.display(composition)
```



## Parameter Partitioning

For optimization it is necessary to differentiate only against meaningful leaves of the PyTree of `evm.Parameters`.
Expand Down Expand Up @@ -47,7 +98,7 @@ If you need to further exclude parameter from being optimized you can either set
Evert component of evermore is compatible with JAX transformations. That means you can `jax.jit`, `jax.vmap`, ... _everything_.
You can e.g. sample the parameter values multiple times vectorized from its prior PDF:

```{code-block} python
```{code-cell} ipython3
import jax
import evermore as evm
Expand All @@ -59,23 +110,7 @@ rng_keys = jax.random.split(rng_key, 100)
vec_sample = jax.vmap(evm.parameter.sample, in_axes=(None, 0))
print(vec_sample(params, rng_keys))
# {'a': NormalParameter(
# value=f32[100,1],
# name=None,
# lower=f32[100,1],
# upper=f32[100,1],
# prior=Normal(mean=f32[100,1], width=f32[100,1]),
# frozen=False,
# ),
# 'b': NormalParameter(
# value=f32[100,1],
# name=None,
# lower=f32[100,1],
# upper=f32[100,1],
# prior=Normal(mean=f32[100,1], width=f32[100,1]),
# frozen=False,
# )}
evm.visualization.display(vec_sample(params, rng_keys))
```

Many minimizers from the JAX ecosystem are e.g. batchable (`optax`, `optimistix`), which allows you vectorize _full fits_, e.g., for embarrassingly parallel likleihood profiles.
Expand Down
6 changes: 5 additions & 1 deletion pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@ sphinx-copybutton = "*"
sphinx-book-theme = "*"
sphinx-design = "*"
sphinx-togglebutton = "*"
myst-nb = "*"

[pypi-dependencies]
penzai = "*"

[tasks]
postinstall = "pip install -e '.[dev]' && pip install pre-commit && pre-commit install"
test = "pytest"
lint = "ruff check . --fix --show-fixes"
checkall = "pre-commit run --all-files"
builddocs = "sphinx-build -M html ./docs ./build -W --keep-going"
builddocs = "rm -rf build/ && sphinx-build -M html ./docs ./build -W --keep-going"
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ dependencies = [
test = ["pytest >=6", "pytest-cov >=3"]
dev = ["pytest >=6", "pytest-cov >=3", "optax", "jaxopt >=0.6"]
docs = [
"sphinx>=7.0",
"myst_parser>=0.13",
"sphinx",
"myst-parser",
"myst-nb",
"sphinx_copybutton",
"sphinx_autodoc_typehints",
"sphinx-book-theme",
"sphinx-design",
"sphinx-togglebutton",
"penzai",
]

[project.urls]
Expand Down
11 changes: 6 additions & 5 deletions src/evermore/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,19 +351,20 @@ class Compose(ModifierBase):

def __init__(self, *modifiers: ModifierLike) -> None:
self.modifiers = list(modifiers)
# unroll nested compositions

def unroll_modifiers(self) -> list[ModifierLike]:
_modifiers = []
for mod in self.modifiers:
if isinstance(mod, Compose):
_modifiers.extend(mod.modifiers)
else:
assert isinstance(mod, ModifierBase)
assert isinstance(mod, ModifierLike)
_modifiers.append(mod)
# by now all are modifiers
self.modifiers = _modifiers
return _modifiers

def __len__(self) -> int:
return len(self.modifiers)
return len(self.unroll_modifiers())

def offset_and_scale(self, hist: Array) -> OffsetAndScale:
from collections import defaultdict
Expand All @@ -374,7 +375,7 @@ def offset_and_scale(self, hist: Array) -> OffsetAndScale:

groups = defaultdict(list)
# first group modifiers into same tree structures
for mod in self.modifiers:
for mod in self.unroll_modifiers():
groups[hash(jtu.tree_structure(mod))].append(mod)
# then do the `jax.lax.scan` loops
for _, group_mods in groups.items():
Expand Down
33 changes: 21 additions & 12 deletions src/evermore/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dataclasses
import importlib.util
import threading
from typing import Any

import jax.tree_util as jtu
Expand Down Expand Up @@ -36,17 +37,23 @@ def __dir__():
return __all__


penzai_installed = importlib.util.find_spec("penzai") is not None
@dataclasses.dataclass
class EvermoreClassesContext(threading.local):
cls_types: list[Any] = dataclasses.field(default_factory=list)

EVERMORE_CLASSES = set(
{
Parameter,

Context = EvermoreClassesContext()


Context.cls_types.extend(
[
NormalParameter,
Effect,
Parameter,
Identity,
Linear,
AsymmetricExponential,
VerticalTemplateMorphing,
Effect,
Modifier,
Compose,
Where,
Expand All @@ -58,7 +65,7 @@ def __dir__():
Poisson,
ModifierLike,
PDFLike,
}
]
)


Expand All @@ -76,11 +83,13 @@ def display(tree: PyTree) -> None:
tree = ...
evm.visualization.display(tree)
"""
penzai_installed = importlib.util.find_spec("penzai") is not None

if not penzai_installed:
msg = "install 'penzai' with:\n\n"
msg += "\tpython -m pip install penzai[notebook]"
raise ModuleNotFoundError(msg)

try:
from IPython import get_ipython

Expand All @@ -96,20 +105,20 @@ def display(tree: PyTree) -> None:
from penzai import pz

with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
converted_tree = convert_tree(tree)
pz.ts.display(converted_tree)
pz_tree = convert_tree(tree)
pz.ts.display(pz_tree)


def convert_tree(tree: PyTree) -> PyTree:
from functools import partial

for cls in EVERMORE_CLASSES:
for cls in Context.cls_types:

def _is_evm_cls(leaf: Any, evm_cls: Any) -> bool:
return isinstance(leaf, evm_cls)
def _is_evm_cls(leaf: Any, cls: Any) -> bool:
return isinstance(leaf, cls)

tree = jtu.tree_map(
partial(_convert, cls=cls), tree, is_leaf=partial(_is_evm_cls, evm_cls=cls)
partial(_convert, cls=cls), tree, is_leaf=partial(_is_evm_cls, cls=cls)
)
return tree

Expand Down

0 comments on commit b36faaf

Please sign in to comment.