From 5faecff3e79c97a4c1a1e71b9085ffcce545c14d Mon Sep 17 00:00:00 2001 From: felixzinn <151917409+felixzinn@users.noreply.github.com> Date: Fri, 15 Nov 2024 10:20:07 +0100 Subject: [PATCH] add bounds for parameter in toy example (#19) --- examples/toy_generation.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/toy_generation.py b/examples/toy_generation.py index 5c5faa8..862e78d 100644 --- a/examples/toy_generation.py +++ b/examples/toy_generation.py @@ -1,5 +1,6 @@ import equinox as eqx import jax +import jax.numpy as jnp from jaxtyping import Array, PRNGKeyArray from model import hists, model, observation @@ -7,6 +8,10 @@ key = jax.random.PRNGKey(0) +# set lower and upper bounds for the mu parameter +model = eqx.tree_at(lambda t: t.mu.lower, model, jnp.array([0.0])) +model = eqx.tree_at(lambda t: t.mu.upper, model, jnp.array([10.0])) + # generate a new model with sampled parameters according to their constraint pdfs toymodel = evm.parameter.sample(model, key)