Skip to content

Commit

Permalink
fix staterror poisson
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed May 8, 2024
1 parent c6f6361 commit b950146
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 12 deletions.
13 changes: 7 additions & 6 deletions examples/bin_by_bin_uncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,16 @@ def __call__(self, hists: dict) -> dict[str, Array]:


hists = {
"signal": jnp.array([3]),
"bkg1": jnp.array([10]),
"bkg2": jnp.array([20]),
"signal": jnp.array([3.0]),
"bkg1": jnp.array([10.0]),
"bkg2": jnp.array([20.0]),
}
histsw2 = {
"signal": jnp.array([5]),
"bkg1": jnp.array([11]),
"bkg2": jnp.array([25]),
"signal": jnp.array([5.0]),
"bkg1": jnp.array([11.0]),
"bkg2": jnp.array([25.0]),
}
observation = jnp.array([34.0])

model = SPlusBModel(hists, histsw2)

Expand Down
22 changes: 16 additions & 6 deletions src/evermore/staterror.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(

self.ntot = sum_over_leaves(self.hists)
self.etot = jnp.sqrt(sum_over_leaves(self.histsw2))
ntot_eff = jnp.round(self.ntot**2 / self.etot**2, decimals=0)
ntot_eff = self.ntot**2 / self.etot**2
self.mask = ntot_eff > self.threshold

# setup params
Expand All @@ -83,19 +83,29 @@ def __init__(
lambda hist: NormalParameter(value=jnp.zeros_like(hist)), self.hists
)
self.poissons_per_process = jtu.tree_map(
lambda hist: Parameter(
value=jnp.zeros_like(hist),
prior=Poisson(lamb=cast(Array, jnp.where(hist > 0.0, hist, 1.0))),
lambda w, w2: Parameter(
value=jnp.zeros_like(w),
prior=Poisson(
lamb=cast(
Array,
jnp.where(
(w**2 / jnp.sqrt(w2**2)) > 0.0,
(w**2 / jnp.sqrt(w2**2)),
1.0,
),
)
),
),
self.hists,
self.histsw2,
)

def modifier(self, getter: Callable) -> ModifierLike:
# see: https://github.com/cms-analysis/HiggsAnalysis-CombinedLimit/pull/929
# and: https://cms-analysis.github.io/HiggsAnalysis-CombinedLimit/latest/part2/bin-wise-stats/#usage-instructions

# poisson case per process
# if w > 0.0, then poisson, else noop (no effect)
# if w > 0.0, then poisson, else Identity (no effect)
# since w <= 0 leads to NaNs in derivatives, we need to mask them
w = getter(self.hists)
poisson_params = getter(self.poissons_per_process)
Expand Down Expand Up @@ -141,7 +151,7 @@ def modifier(self, getter: Callable) -> ModifierLike:
# if n_i_eff > threshold or e_i > n_i or n_i <= 0.0:
# apply per process gaussian(width=e_i/n_i)
# else:
# apply per process poisson(lamb=n_i)
# apply per process poisson(lamb=n_i_eff)
per_process_mask = (
((w**2 / w2**2) > self.threshold) | (jnp.sqrt(w2) > w) | (w <= 0)
)
Expand Down

0 comments on commit b950146

Please sign in to comment.