From b950146e31e336161c43833e9d4b63d0d8e1292d Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Wed, 8 May 2024 17:33:09 +0200 Subject: [PATCH] fix staterror poisson --- examples/bin_by_bin_uncs.py | 13 +++++++------ src/evermore/staterror.py | 22 ++++++++++++++++------ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/examples/bin_by_bin_uncs.py b/examples/bin_by_bin_uncs.py index e0270fe..c70a099 100644 --- a/examples/bin_by_bin_uncs.py +++ b/examples/bin_by_bin_uncs.py @@ -38,15 +38,16 @@ def __call__(self, hists: dict) -> dict[str, Array]: hists = { - "signal": jnp.array([3]), - "bkg1": jnp.array([10]), - "bkg2": jnp.array([20]), + "signal": jnp.array([3.0]), + "bkg1": jnp.array([10.0]), + "bkg2": jnp.array([20.0]), } histsw2 = { - "signal": jnp.array([5]), - "bkg1": jnp.array([11]), - "bkg2": jnp.array([25]), + "signal": jnp.array([5.0]), + "bkg1": jnp.array([11.0]), + "bkg2": jnp.array([25.0]), } +observation = jnp.array([34.0]) model = SPlusBModel(hists, histsw2) diff --git a/src/evermore/staterror.py b/src/evermore/staterror.py index 96291a5..30156d5 100644 --- a/src/evermore/staterror.py +++ b/src/evermore/staterror.py @@ -74,7 +74,7 @@ def __init__( self.ntot = sum_over_leaves(self.hists) self.etot = jnp.sqrt(sum_over_leaves(self.histsw2)) - ntot_eff = jnp.round(self.ntot**2 / self.etot**2, decimals=0) + ntot_eff = self.ntot**2 / self.etot**2 self.mask = ntot_eff > self.threshold # setup params @@ -83,11 +83,21 @@ def __init__( lambda hist: NormalParameter(value=jnp.zeros_like(hist)), self.hists ) self.poissons_per_process = jtu.tree_map( - lambda hist: Parameter( - value=jnp.zeros_like(hist), - prior=Poisson(lamb=cast(Array, jnp.where(hist > 0.0, hist, 1.0))), + lambda w, w2: Parameter( + value=jnp.zeros_like(w), + prior=Poisson( + lamb=cast( + Array, + jnp.where( + (w**2 / jnp.sqrt(w2**2)) > 0.0, + (w**2 / jnp.sqrt(w2**2)), + 1.0, + ), + ) + ), ), self.hists, + self.histsw2, ) def modifier(self, getter: Callable) -> ModifierLike: @@ -95,7 +105,7 @@ def modifier(self, getter: Callable) -> ModifierLike: # and: https://cms-analysis.github.io/HiggsAnalysis-CombinedLimit/latest/part2/bin-wise-stats/#usage-instructions # poisson case per process - # if w > 0.0, then poisson, else noop (no effect) + # if w > 0.0, then poisson, else Identity (no effect) # since w <= 0 leads to NaNs in derivatives, we need to mask them w = getter(self.hists) poisson_params = getter(self.poissons_per_process) @@ -141,7 +151,7 @@ def modifier(self, getter: Callable) -> ModifierLike: # if n_i_eff > threshold or e_i > n_i or n_i <= 0.0: # apply per process gaussian(width=e_i/n_i) # else: - # apply per process poisson(lamb=n_i) + # apply per process poisson(lamb=n_i_eff) per_process_mask = ( ((w**2 / w2**2) > self.threshold) | (jnp.sqrt(w2) > w) | (w <= 0) )