Skip to content

Commit

Permalink
add penzai visualization helper function
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Apr 30, 2024
1 parent 8da6646 commit a5eae7b
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/evermore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"util",
"modifier",
"staterror",
"visualization",
# explicitely expose some classes
"Parameter",
"NormalParameter",
Expand All @@ -44,6 +45,7 @@ def __dir__():
pdf,
staterror,
util,
visualization,
)
from evermore.modifier import Modifier # noqa: E402
from evermore.parameter import ( # noqa: E402,
Expand Down
1 change: 1 addition & 0 deletions src/evermore/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def broadcast(self) -> OffsetAndScale:
)


@runtime_checkable
class ModifierLike(Protocol):
def offset_and_scale(self, hist: Array) -> OffsetAndScale: ...
def __call__(self, hist: Array) -> Array: ...
Expand Down
124 changes: 124 additions & 0 deletions src/evermore/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from __future__ import annotations

import dataclasses
import importlib.util
from typing import Any

import jax.tree_util as jtu
from jaxtyping import Array, PyTree

import evermore as evm
from evermore.custom_types import ModifierLike, PDFLike

__all__ = [
"display",
]


def __dir__():
return __all__


EVERMORE_CLASSES = set(
{
evm.Parameter,
evm.NormalParameter,
evm.effect.Effect,
evm.effect.Identity,
evm.effect.Linear,
evm.effect.AsymmetricExponential,
evm.effect.VerticalTemplateMorphing,
evm.modifier.Modifier,
evm.modifier.Compose,
evm.modifier.Where,
evm.modifier.BooleanMask,
evm.modifier.Transform,
evm.modifier.TransformScale,
evm.modifier.TransformOffset,
evm.pdf.Normal,
evm.pdf.Poisson,
ModifierLike,
PDFLike,
}
)


def display(tree: PyTree) -> None:
"""
Visualize PyTrees of evermore components with penzai in a notebook.
Usage:
.. code-block:: python
import evermore as evm
tree = ...
evm.visualization.display(tree)
"""

penzai_installed = importlib.util.find_spec("penzai") is not None
if not penzai_installed:
msg = "install 'penzai' with:\n\n"
msg += "\tpython -m pip install penzai[notebook]"

try:
from IPython import get_ipython

in_ipython = get_ipython() is not None
except ImportError:
in_ipython = False

if not in_ipython:
print(tree)
return

# now we can pretty-print
from penzai import pz

with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
converted_tree = convert_tree(tree)
pz.ts.display(converted_tree)


def convert_tree(tree: PyTree) -> PyTree:
from functools import partial

for cls in EVERMORE_CLASSES:

def _is_evm_cls(leaf: Any, evm_cls: Any) -> bool:
return isinstance(leaf, evm_cls)

tree = jtu.tree_map(
partial(_convert, cls=cls), tree, is_leaf=partial(_is_evm_cls, evm_cls=cls)
)
return tree


def _convert(leaf: Any, cls: Any) -> Any:
from penzai import pz

if isinstance(leaf, cls) and dataclasses.is_dataclass(leaf):
fields = dataclasses.fields(leaf)

leaf_cls = type(leaf)
attributes: dict[str, Any] = {
"__annotations__": {field.name: field.type for field in fields}
}

if callable(leaf_cls):
attributes["__call__"] = leaf_cls.__call__

def _pretty(x: Any) -> Any:
if isinstance(x, Array) and x.size == 1:
return x.item()
return x

attrs = {k: _pretty(getattr(leaf, k)) for k in attributes["__annotations__"]}

new_cls = pz.pytree_dataclass(
type(leaf_cls.__name__, (pz.Layer,), dict(attributes))
)
return new_cls(**attrs)
return leaf

0 comments on commit a5eae7b

Please sign in to comment.