Skip to content

Commit

Permalink
fix mean logit calculation in neurd loss for rnad
Browse files Browse the repository at this point in the history
  • Loading branch information
spktrm committed Dec 13, 2023
1 parent 7c58b6c commit 8156646
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions open_spiel/python/algorithms/rnad/rnad.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -562,6 +562,9 @@ def get_loss_nerd(logit_list: Sequence[chex.Array],
"""Define the nerd loss."""
assert isinstance(importance_sampling_correction, list)
loss_pi_list = []

num_valid_actions = jnp.sum(legal_actions, axis=-1, keepdims=True)

for k, (logit_pi, pi, q_vr, is_c) in enumerate(
zip(logit_list, policy_list, q_vr_list, importance_sampling_correction)):
assert logit_pi.shape[0] == q_vr.shape[0]
Expand All @@ -570,9 +573,12 @@ def get_loss_nerd(logit_list: Sequence[chex.Array],
adv_pi = is_c * adv_pi # importance sampling correction
adv_pi = jnp.clip(adv_pi, a_min=-clip, a_max=clip)
adv_pi = lax.stop_gradient(adv_pi)

logits = logit_pi - jnp.mean(
logit_pi * legal_actions, axis=-1, keepdims=True)

valid_logit_sum = jnp.sum(logit_pi * legal_actions, axis=-1, keepdims=True)
mean_logit = valid_logit_sum / num_valid_actions

# Center every logit on the mean computed over the legal actions only
# (masked sum of logits divided by the number of legal actions).
logits = logit_pi - mean_logit

threshold_center = jnp.zeros_like(logits)

Expand Down

0 comments on commit 8156646

Please sign in to comment.