Dear Community,
I encountered a small challenge while using a batch normalization layer together with nnx.vmap. To illustrate the issue, I have put together the minimal example below. Based on my understanding of the documentation for flax.nnx.vmap, the problem seems to stem from the handling of BatchStat, which requires special consideration when using vmap.
Currently, I am struggling to make the final example in the attached code work. Does anyone have a suggestion for how to adjust the loss function so that it works correctly with nnx.vmap? (A rough sketch of the kind of adjustment I have in mind is included after the failing example below.)
Thank you very much!
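For context, this is the mechanism I understand the flax.nnx transforms documentation to describe for nnx.vmap: per-state-type axes are declared with nnx.StateAxes, for example mapping the Dropout RNG state over the batch axis while broadcasting everything else. The snippet below is only my reading of that mechanism on a plain Dropout layer (a toy example of my own, not the failing code):

import jax.numpy as jnp
from flax import nnx

layer = nnx.Dropout(0.2, rngs=nnx.Rngs(0))
xs = jnp.ones((8, 4))

# Map the RNG state over the batch axis, broadcast all other state.
state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})

@nnx.split_rngs(splits=8)                       # one dropout key per sample
@nnx.vmap(in_axes=(state_axes, 0), out_axes=0)
def forward(layer, x):
    return layer(x)

ys = forward(layer, xs)                         # per-sample dropout masks, shared state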
import jax.numpy as jnp
from jax import random
from flax import nnx
import optax
#####################################################
#####################################################
#####################################################
# Using nnx.vmap without Batch Norm (GOOD)
#####################################################
#####################################################
#####################################################
class Model(nnx.Module):
    def __init__(self, din, dmid, rngs: nnx.Rngs):
        self.linear = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(0.2, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.dropout(self.linear(x)))
        return x
class ModelStack(nnx.Module):
    def __init__(self, layers, dinput, doutput, rngs: nnx.Rngs):
        self.layers = []
        self.layers.append(Model(dinput, layers[0], rngs))
        for i, layer in enumerate(layers[1:]):
            # layers[i] is the entry preceding `layer` in the original list
            self.layers.append(Model(layers[i], layer, rngs))
        self.layers.append(nnx.Linear(layers[-1], doutput, rngs=rngs))

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
model = ModelStack([64, 64], 2, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing
#####################################################
# Train Step Function
#####################################################
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        # Per-sample squared error, vectorized over the batch with nnx.vmap.
        def loss(x_, y_):
            y_pred = model(x_)
            return 0.5 * jnp.inner(y_pred - y_, y_pred - y_)
        return jnp.mean(nnx.vmap(loss)(x, y), axis=0)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    return loss
#####################################################
# Training Loop
#####################################################
key = random.PRNGKey(0)
key, sub_key = random.split(key)
x_set = random.uniform(key, (100, 2), jnp.float32)
y_set = random.uniform(sub_key, (100, 3), jnp.float32)

for i in range(1000):
    loss = train_step(model, optimizer, x_set, y_set)
    print(f"loss: {loss}")
print("Done without Batch Norm")
#############################################################
#############################################################
#############################################################
# Using Batch Norm without nnx.vmap in train_step (GOOD)
#############################################################
#############################################################
#############################################################
class ModelBatchNorm(nnx.Module):
    def __init__(self, din, dmid, rngs: nnx.Rngs):
        self.linear = nnx.Linear(din, dmid, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.dropout = nnx.Dropout(0.2, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.dropout(self.bn(self.linear(x))))
        return x
class ModelStackBatchNorm(nnx.Module):
    def __init__(self, layers, dinput, doutput, rngs: nnx.Rngs):
        self.layers = []
        self.layers.append(ModelBatchNorm(dinput, layers[0], rngs))
        for i, layer in enumerate(layers[1:]):
            # layers[i] is the entry preceding `layer` in the original list
            self.layers.append(ModelBatchNorm(layers[i], layer, rngs))
        self.layers.append(nnx.Linear(layers[-1], doutput, rngs=rngs))

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
model = ModelStackBatchNorm([64, 64], 2, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing
#####################################################
# Train Step Function
#####################################################
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        # Full-batch forward pass: BatchNorm sees the whole batch at once.
        y_pred = model(x)
        return ((y_pred - y) ** 2).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    return loss
#####################################################
# Training Loop
#####################################################
key = random.PRNGKey(0)
key, sub_key = random.split(key)
x_set = random.uniform(key, (100, 2), jnp.float32)
y_set = random.uniform(sub_key, (100, 3), jnp.float32)

for i in range(1000):
    loss = train_step(model, optimizer, x_set, y_set)
    print(f"loss: {loss}")
print("Done with Batch Norm but without nnx.vmap")
#####################################################
#####################################################
#####################################################
# Using Batch Norm with nnx.vmap in train_step (BAD)
#####################################################
#####################################################
#####################################################
class ModelBatchNorm(nnx.Module):
    def __init__(self, din, dmid, rngs: nnx.Rngs):
        self.linear = nnx.Linear(din, dmid, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.dropout = nnx.Dropout(0.2, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.dropout(self.bn(self.linear(x))))
        return x
class ModelStackBatchNorm(nnx.Module):
    def __init__(self, layers, dinput, doutput, rngs: nnx.Rngs):
        self.layers = []
        self.layers.append(ModelBatchNorm(dinput, layers[0], rngs))
        for i, layer in enumerate(layers[1:]):
            # layers[i] is the entry preceding `layer` in the original list
            self.layers.append(ModelBatchNorm(layers[i], layer, rngs))
        self.layers.append(nnx.Linear(layers[-1], doutput, rngs=rngs))

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
model = ModelStackBatchNorm([64, 64], 2, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing
#####################################################
# Train Step Function
#####################################################
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        # Same per-sample vmapped loss as in the first example, but now each
        # mapped call runs BatchNorm (and its BatchStat update) on a single
        # sample inside nnx.vmap; this is the part I cannot get to work.
        def loss(x_, y_):
            y_pred = model(x_)
            return 0.5 * jnp.inner(y_pred - y_, y_pred - y_)
        return jnp.mean(nnx.vmap(loss)(x, y), axis=0)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    return loss
#####################################################
# Training Loop
#####################################################
key = random.PRNGKey(0)
key, sub_key = random.split(key)
x_set = random.uniform(key, (100, 2), jnp.float32)
y_set = random.uniform(sub_key, (100, 3), jnp.float32)

for i in range(1000):
    loss = train_step(model, optimizer, x_set, y_set)
    print(f"loss: {loss}")
print("Done")