-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add penzai visualization helper function
- Loading branch information
1 parent
8da6646
commit a5eae7b
Showing
3 changed files
with
127 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
from __future__ import annotations | ||
|
||
import dataclasses | ||
import importlib.util | ||
from typing import Any | ||
|
||
import jax.tree_util as jtu | ||
from jaxtyping import Array, PyTree | ||
|
||
import evermore as evm | ||
from evermore.custom_types import ModifierLike, PDFLike | ||
|
||
__all__ = [ | ||
"display", | ||
] | ||
|
||
|
||
def __dir__(): | ||
return __all__ | ||
|
||
|
||
EVERMORE_CLASSES = set( | ||
{ | ||
evm.Parameter, | ||
evm.NormalParameter, | ||
evm.effect.Effect, | ||
evm.effect.Identity, | ||
evm.effect.Linear, | ||
evm.effect.AsymmetricExponential, | ||
evm.effect.VerticalTemplateMorphing, | ||
evm.modifier.Modifier, | ||
evm.modifier.Compose, | ||
evm.modifier.Where, | ||
evm.modifier.BooleanMask, | ||
evm.modifier.Transform, | ||
evm.modifier.TransformScale, | ||
evm.modifier.TransformOffset, | ||
evm.pdf.Normal, | ||
evm.pdf.Poisson, | ||
ModifierLike, | ||
PDFLike, | ||
} | ||
) | ||
|
||
|
||
def display(tree: PyTree) -> None: | ||
""" | ||
Visualize PyTrees of evermore components with penzai in a notebook. | ||
Usage: | ||
.. code-block:: python | ||
import evermore as evm | ||
tree = ... | ||
evm.visualization.display(tree) | ||
""" | ||
|
||
penzai_installed = importlib.util.find_spec("penzai") is not None | ||
if not penzai_installed: | ||
msg = "install 'penzai' with:\n\n" | ||
msg += "\tpython -m pip install penzai[notebook]" | ||
|
||
try: | ||
from IPython import get_ipython | ||
|
||
in_ipython = get_ipython() is not None | ||
except ImportError: | ||
in_ipython = False | ||
|
||
if not in_ipython: | ||
print(tree) | ||
return | ||
|
||
# now we can pretty-print | ||
from penzai import pz | ||
|
||
with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()): | ||
converted_tree = convert_tree(tree) | ||
pz.ts.display(converted_tree) | ||
|
||
|
||
def convert_tree(tree: PyTree) -> PyTree: | ||
from functools import partial | ||
|
||
for cls in EVERMORE_CLASSES: | ||
|
||
def _is_evm_cls(leaf: Any, evm_cls: Any) -> bool: | ||
return isinstance(leaf, evm_cls) | ||
|
||
tree = jtu.tree_map( | ||
partial(_convert, cls=cls), tree, is_leaf=partial(_is_evm_cls, evm_cls=cls) | ||
) | ||
return tree | ||
|
||
|
||
def _convert(leaf: Any, cls: Any) -> Any: | ||
from penzai import pz | ||
|
||
if isinstance(leaf, cls) and dataclasses.is_dataclass(leaf): | ||
fields = dataclasses.fields(leaf) | ||
|
||
leaf_cls = type(leaf) | ||
attributes: dict[str, Any] = { | ||
"__annotations__": {field.name: field.type for field in fields} | ||
} | ||
|
||
if callable(leaf_cls): | ||
attributes["__call__"] = leaf_cls.__call__ | ||
|
||
def _pretty(x: Any) -> Any: | ||
if isinstance(x, Array) and x.size == 1: | ||
return x.item() | ||
return x | ||
|
||
attrs = {k: _pretty(getattr(leaf, k)) for k in attributes["__annotations__"]} | ||
|
||
new_cls = pz.pytree_dataclass( | ||
type(leaf_cls.__name__, (pz.Layer,), dict(attributes)) | ||
) | ||
return new_cls(**attrs) | ||
return leaf |