Hey @phlippe, the gradients w.r.t. the kernel can be checked like this:

    import jax
    from jax import random

    # Initialization (MyConv is the masked convolution module from the original question)
    feat_dim = 128
    module = MyConv(c_out=feat_dim, kernel_size=3)
    inp = random.normal(random.PRNGKey(123), (1, 5, 5, feat_dim))
    params = module.init(random.PRNGKey(0), inp)

    # Gradient of the center output pixel w.r.t. the parameters
    # (jax.grad differentiates w.r.t. its first argument by default)
    grad_fn = jax.grad(lambda params, x: module.apply(params, x)[:, 2, 2, :].sum())
    grads = grad_fn(params, inp)

    # For each input/output channel pair (i, j), print a (3, 3) boolean map showing
    # which spatial kernel positions received a zero gradient
    for i in range(3):
        for j in range(3):
            print()
            print(grads['params']['Conv_0']['kernel'][..., i, j] == 0)

Output:
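For reference, here is a minimal sketch (assuming `MyConv` is the masked convolution module from the original question) of the complementary check the next reply refers to: gradients of the same center output pixel taken w.r.t. the input rather than the parameters, obtained by passing `argnums=1` to `jax.grad`:

```python
import jax
from jax import random

# MyConv is assumed to be the masked convolution module from the original question.
feat_dim = 128
module = MyConv(c_out=feat_dim, kernel_size=3)
inp = random.normal(random.PRNGKey(123), (1, 5, 5, feat_dim))
params = module.init(random.PRNGKey(0), inp)

# Differentiate w.r.t. the second argument (the input) instead of the parameters.
input_grad_fn = jax.grad(lambda p, x: module.apply(p, x)[:, 2, 2, :].sum(), argnums=1)
input_grads = input_grad_fn(params, inp)   # shape (1, 5, 5, feat_dim)

# Aggregate over channels: rough (5, 5) map of which input positions received a zero gradient.
print(input_grads[0].sum(axis=-1) == 0)
```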
-
Hi @cgarciae, thanks for your response! The calculation of the gradients w.r.t. the input is more of a debugging step for a single layer, and becomes relevant once you stack multiple layers. For example, consider a two-layer network with these masks. When calculating the gradients w.r.t. the parameters of the input layer, we backpropagate through the features that feed into the output layer, since we can change these features by changing the parameters of the input layer, and the features have a direct impact on the output/loss. Hence, you can see the gradient above as an intermediate step of what is done in a deep NN. The gradient we would like to see w.r.t. the input has zeros in the second-to-last row, since the weights for these elements are set to zero by the mask and hence should not have any effect on the output. In other words, this is the expected output:
Now consider the behavior above with the gradients w.r.t. the inputs. Essentially, the gradients indicate that changing the intermediate features at position [3,2] (row 3, col 2, zero-indexed) leads to a change in the output features at [2,2]. However, this is not intended, since in the output layer the kernel has been masked such that position [3,2] has a weight of zero for output [2,2]. This influences the gradients w.r.t. the parameters of the input layer, and if you stack more layers, for example in a PixelCNN, it affects layers further down the network. In practice, the network is still trainable because the leaked gradients are a couple of orders of magnitude smaller than those of the non-masked elements. After a bit of further investigation, it seems that this behavior does not originate from Flax itself but from JAX's convolution operator, so I opened an issue on the JAX repo to get to the bottom of this behavior.
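A rough sketch of the two-layer setup described above, again assuming `MyConv` is the masked convolution from the original question: stacking two masked layers and taking the gradient of the center output pixel w.r.t. all parameters shows where the leaked input gradients of the second layer feed into the parameter gradients of the first layer.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax import random

# MyConv is assumed to be the masked convolution from the original question,
# applying a PixelCNN-style causal mask.
class TwoLayer(nn.Module):
    feat_dim: int = 128

    @nn.compact
    def __call__(self, x):
        x = MyConv(c_out=self.feat_dim, kernel_size=3)(x)
        return MyConv(c_out=self.feat_dim, kernel_size=3)(x)

model = TwoLayer()
inp = random.normal(random.PRNGKey(123), (1, 5, 5, 128))
params = model.init(random.PRNGKey(0), inp)

# Gradient of the center output pixel w.r.t. all parameters. The first layer's
# kernel gradient picks up (tiny) contributions routed through intermediate
# position [3, 2], even though the second layer's mask assigns that position
# zero weight for output [2, 2].
grads = jax.grad(lambda p, x: model.apply(p, x)[:, 2, 2, :].sum())(params, inp)
print(jax.tree_util.tree_map(lambda g: jnp.abs(g).max(), grads))  # gradient magnitudes per parameter
```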
-
I think this is an artifact of the Winograd convolution kernel that NVIDIA implemented to accelerate small convolution kernels. Note that the values that are supposed to be zero are actually pretty small (~10^-9).
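One way to sanity-check this explanation (a sketch, assuming `MyConv` is the masked convolution from the original question and that its mask zeroes out the kernel rows below the center pixel) is to compare the magnitude of the supposedly-zero input gradients on the default GPU backend against the same computation committed to the CPU, where the convolution is typically not computed with the Winograd transform:

```python
import jax
import jax.numpy as jnp
from jax import random

# MyConv is assumed to be the masked convolution from the original question.
feat_dim = 128
module = MyConv(c_out=feat_dim, kernel_size=3)
inp = random.normal(random.PRNGKey(123), (1, 5, 5, feat_dim))
params = module.init(random.PRNGKey(0), inp)

grad_fn = jax.grad(lambda p, x: module.apply(p, x)[:, 2, 2, :].sum(), argnums=1)

# Default backend (GPU here, where cuDNN may pick a Winograd kernel).
gpu_grads = grad_fn(params, inp)

# Same computation with the data committed to the CPU.
cpu = jax.devices("cpu")[0]
cpu_grads = grad_fn(jax.device_put(params, cpu), jax.device_put(inp, cpu))

# Input rows below the center output pixel should not influence it at all.
print(jnp.abs(gpu_grads[0, 3:]).max())  # on GPU: tiny (~1e-9) round-off rather than exact zero
print(jnp.abs(cpu_grads[0, 3:]).max())  # on CPU: expected to be exactly 0.0
```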
-
Hi @jheek. May I ask whether there is currently an implementation of Winograd convolution available in Flax? Or would I need to implement a Winograd convolutional layer manually and replace the original convolutional layer in the model definition?