TensorFlow batch norm not working

Have you noticed that when you use tf.contrib.layers.batch_norm, performance at test time is diminished or completely broken? If so, this post is for you.

Quick summary:
1) As suggested in [1], you have to add this extra code:
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
2) Use a placeholder to indicate whether you are training, and pass it to batch_norm as is_training=your_place_holder (True during training, False at test time).

3) If the model still misbehaves after this, lower the default decay=0.999 to a smaller value such as 0.5.
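To see why point 3) matters: batch_norm keeps population statistics as an exponential moving average, moving = decay * moving + (1 - decay) * batch_stat, initialized at zero. With decay=0.999 the average needs thousands of update steps to approach the true statistics, so after a short training run the stored mean/variance are far off and test-time predictions suffer. A plain-Python sketch of this effect (the ema helper here is mine for illustration, not TensorFlow API):

```python
def ema(values, decay):
    """Exponential moving average as batch norm uses for population stats:
    moving = decay * moving + (1 - decay) * batch_stat, starting from 0."""
    moving = 0.0
    for v in values:
        moving = decay * moving + (1.0 - decay) * v
    return moving

# Suppose every batch mean is 5.0 and we train for 200 steps.
batch_means = [5.0] * 200

slow = ema(batch_means, decay=0.999)  # still far below 5.0 after 200 steps
fast = ema(batch_means, decay=0.5)    # essentially converged to 5.0

print(round(slow, 3), round(fast, 3))
```

With decay=0.999 the moving mean only reaches about 0.9 after 200 steps, while decay=0.5 is already at 5.0, which is why a lower decay helps when training is short.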

Another way to apply batch norm, with a manual setup, is presented in [2]. If this second approach still performs badly, apply point 3). To use it with convolutional neural network (CNN) layers, change tf.nn.moments(inputs, [0]) to tf.nn.moments(inputs, [0, 1, 2]).
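The axis change above is the key detail: for a dense layer you compute one mean/variance per feature over the batch axis, while for a conv layer you compute one per channel over batch, height, and width. A NumPy sketch of the normalization step (batch_norm_manual is a hypothetical helper mirroring tf.nn.moments followed by the usual scale-and-shift; it is not part of TensorFlow):

```python
import numpy as np

def batch_norm_manual(x, gamma, beta, axes, eps=1e-3):
    # Statistics over `axes`, as tf.nn.moments(inputs, axes) would compute.
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    # Normalize, then scale by gamma and shift by beta.
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

# Dense layer: (batch, features) -> one mean/variance per feature, axis 0.
dense = np.random.randn(32, 10)
out_dense = batch_norm_manual(dense, 1.0, 0.0, axes=(0,))

# Conv layer: (batch, height, width, channels) -> one mean/variance per
# channel, computed over axes (0, 1, 2), i.e. per feature map.
conv = np.random.randn(8, 28, 28, 16)
out_conv = batch_norm_manual(conv, 1.0, 0.0, axes=(0, 1, 2))

print(out_dense.shape, out_conv.shape)
```

If you keep axes=(0,) for a conv input, you get a separate statistic per spatial position instead of per channel, which breaks the feature-map-wise normalization batch norm is supposed to perform.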

Problem description
I have seen tutorials and code on GitHub that run the "test" in the same script and session as the training stage. In that setting, a default batch_norm setup doesn't show many problems. Likewise, I used a default setup (without points 1 and 3 above, or without 3) on the MNIST dataset and saw no visible under-performance. However, my own CNN library [3] loads the saved model at test time and works with more complicated images, and in a two-class problem every prediction came out as class "1". Regardless of the number of iterations or number of layers, I always got class 1, so when measuring performance against the ground truth, all negative and positive samples were classified as positive. The three points above solved this issue.

[1] https://www.tensorflow.org/api_docs/python/tf/contrib/layers/batch_norm
[2] https://r2rt.com/implementing-batch-normalization-in-tensorflow.html
[3] https://github.com/bsaldivaremc2/CNN_quick_2


Comments

  1. Is there an equivalent necessity in Keras?

    1. I don't know. I have used Keras for training but never tried to export a model and test it separately.

