Skip to content

Commit

Permalink
add modifier tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Apr 29, 2024
1 parent fa6f340 commit 8ccffcd
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 15 deletions.
6 changes: 3 additions & 3 deletions src/evermore/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class Transform(ModifierBase):
# -> Array([ 5.049494, 20.197975, 30.296963], dtype=float32)
"""

transform_fn: Callable[[OffsetAndScale], OffsetAndScale] = eqx.field(static=True)
transform_fn: Callable = eqx.field(static=True)
modifier: ModifierLike

def offset_and_scale(self, hist: Array) -> OffsetAndScale:
Expand All @@ -283,7 +283,7 @@ def offset_and_scale(self, hist: Array) -> OffsetAndScale:


class TransformOffset(ModifierBase):
transform_fn: Callable[[Array], Array] = eqx.field(static=True)
transform_fn: Callable = eqx.field(static=True)
modifier: ModifierLike

def offset_and_scale(self, hist: Array) -> OffsetAndScale:
Expand All @@ -292,7 +292,7 @@ def offset_and_scale(self, hist: Array) -> OffsetAndScale:


class TransformScale(ModifierBase):
transform_fn: Callable[[Array], Array] = eqx.field(static=True)
transform_fn: Callable = eqx.field(static=True)
modifier: ModifierLike

def offset_and_scale(self, hist: Array) -> OffsetAndScale:
Expand Down
7 changes: 4 additions & 3 deletions tests/test_loss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import jax.numpy as jnp
import pytest

import evermore as evm

Expand All @@ -18,9 +19,9 @@ def test_get_log_probs():
}

log_probs = evm.loss.get_log_probs(params)
assert log_probs["a"] == -0.125
assert log_probs["b"] == 0.0
assert log_probs["c"] == 0.0
assert log_probs["a"] == pytest.approx(-0.125)
assert log_probs["b"] == pytest.approx(0.0)
assert log_probs["c"] == pytest.approx(0.0)


def test_get_boundary_constraints():
Expand Down
66 changes: 57 additions & 9 deletions tests/test_modifier.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,77 @@
from __future__ import annotations

import jax.numpy as jnp
import numpy as np

import evermore as evm


def test_Modifier():
pass
param = evm.Parameter(value=1.1)
modifier = param.scale()

hist = jnp.array([1, 2, 3])

assert isinstance(modifier, evm.Modifier)
np.testing.assert_allclose(modifier(hist), jnp.array([1.1, 2.2, 3.3]))


def test_Where():
pass
param1 = evm.Parameter(value=1.0)
param2 = evm.Parameter(value=1.1)
modifier1 = param1.scale()
modifier2 = param2.scale()

hist = jnp.array([1, 2, 3])

where_mod = evm.modifier.Where(hist > 1.5, modifier2, modifier1)
np.testing.assert_allclose(where_mod(hist), jnp.array([1, 2.2, 3.3]))


def test_BooleanMask():
pass
param = evm.Parameter(value=1.1)
modifier = param.scale()

hist = jnp.array([1, 2, 3])

masked_mod = evm.modifier.BooleanMask(jnp.array([True, False, True]), modifier)
np.testing.assert_allclose(masked_mod(hist), jnp.array([1.1, 2, 3.3]))


def test_Transform():
pass
param = evm.Parameter(value=1.1)
modifier = param.scale()

hist = jnp.array([1, 2, 3])

sqrt_modifier = evm.modifier.Transform(jnp.sqrt, modifier)
np.testing.assert_allclose(
sqrt_modifier(hist), jnp.array([1.0488088, 2.0976176, 3.1464264])
)

def test_TransformOffset():
pass

def test_mix_modifiers():
param = evm.Parameter(value=1.1)
modifier = param.scale()

def test_TransformScale():
pass
hist = jnp.array([1, 2, 3])

sqrt_modifier = evm.modifier.Transform(jnp.sqrt, modifier)
sqrt_masked_modifier = evm.modifier.BooleanMask(
jnp.array([True, False, True]), sqrt_modifier
)
np.testing.assert_allclose(
sqrt_masked_modifier(hist), jnp.array([1.0488088, 2, 3.1464264])
)


def test_Compose():
pass
param1 = evm.Parameter(value=1.0)
param2 = evm.Parameter(value=1.1)
modifier1 = param1.scale()
modifier2 = param2.scale()

hist = jnp.array([1, 2, 3])

composition = modifier1 @ modifier2
np.testing.assert_allclose(composition(hist), jnp.array([1.1, 2.2, 3.3]))

0 comments on commit 8ccffcd

Please sign in to comment.