@Yano 2017-07-15T21:41:40Z

A Brief Analysis of Caffe MNIST

deeplearning


Overview

Training on MNIST really takes just three scripts, run from the Caffe root directory:

    cd "$CAFFE_ROOT"
    ./data/mnist/get_mnist.sh
    ./examples/mnist/create_mnist.sh
    ./examples/mnist/train_lenet.sh

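The MNIST database is a collection of images of handwritten digits; each image shows a single digit from 0 to 9. Every image is an anti-aliased grayscale image of 28×28 pixels. The digit itself is size-normalized to fit a 20×20 box, preserving its original aspect ratio, and sits in the middle of the image.

The official site does not provide the images in JPG format. To get an intuitive feel for the data, the JPG images can be downloaded from this link.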

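Data distribution

MNIST is a mix of two NIST databases: SD-3, collected from Census Bureau employees, and SD-1, collected from high-school students. There are 60,000 training samples and 10,000 test samples. In both the training set and the test set, half of the samples were written by employees and half by students. The 60,000 training samples come from roughly 250 writers, and the writers of the training set and the test set do not overlap.

The MNIST database also retains the correspondence between handwritten digits and writer identity.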

./data/mnist/get_mnist.sh

    #!/usr/bin/env sh
    # This script downloads the MNIST data and unzips it.

    # Work from the directory that contains this script.
    DIR="$( cd "$(dirname "$0")" ; pwd -P )"
    cd "$DIR"

    echo "Downloading..."

    for fname in train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
    do
        # Skip files that have already been downloaded and unpacked.
        if [ ! -e "$fname" ]; then
            wget --no-check-certificate "http://yann.lecun.com/exdb/mnist/${fname}.gz"
            gunzip "${fname}.gz"
        fi
    done

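This script does nothing more than download the MNIST files from the official site and unzip them. The four files (images and labels for the training set and the test set) are:

    train-images-idx3-ubyte
    train-labels-idx1-ubyte
    t10k-images-idx3-ubyte
    t10k-labels-idx1-ubyte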
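The idx3 and idx1 parts of the file names refer to the IDX format, where a small header gives the data type and the size of each dimension. The following is a minimal sketch of reading that header, not the code Caffe itself uses. The helper name `parse_idx_header` is made up for illustration, and the headers are built in memory rather than read from the downloaded files.

```python
import struct

def parse_idx_header(buf):
    """Return (dtype_code, dims) from the start of an IDX byte string."""
    # Magic number: two zero bytes, a data-type code, then the number of dimensions.
    zero, dtype_code, ndim = struct.unpack(">HBB", buf[:4])
    if zero != 0:
        raise ValueError("not an IDX file")
    # Each dimension size is a 4-byte big-endian unsigned integer.
    dims = struct.unpack(">" + "I" * ndim, buf[4:4 + 4 * ndim])
    return dtype_code, dims

# Synthetic header shaped like train-images-idx3-ubyte (no pixel data attached).
fake_images_header = struct.pack(">HBBIII", 0, 0x08, 3, 60000, 28, 28)
# Synthetic header shaped like train-labels-idx1-ubyte.
fake_labels_header = struct.pack(">HBBI", 0, 0x08, 1, 60000)

print(parse_idx_header(fake_images_header))  # (8, (60000, 28, 28))
print(parse_idx_header(fake_labels_header))  # (8, (60000,))
```

Type code 0x08 means unsigned byte, which is where the ubyte suffix comes from. To inspect a real file, pass the first 16 bytes of the unzipped file to the same function.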

./examples/mnist/create_mnist.sh

    #!/usr/bin/env sh
    # This script converts the mnist data into lmdb/leveldb format,
    # depending on the value assigned to $BACKEND.
    set -e

    EXAMPLE=examples/mnist
    DATA=data/mnist
    BUILD=build/examples/mnist

    BACKEND="lmdb"

    echo "Creating ${BACKEND}..."

    rm -rf $EXAMPLE/mnist_train_${BACKEND}
    rm -rf $EXAMPLE/mnist_test_${BACKEND}

    $BUILD/convert_mnist_data.bin $DATA/train-images-idx3-ubyte \
      $DATA/train-labels-idx1-ubyte $EXAMPLE/mnist_train_${BACKEND} --backend=${BACKEND}
    $BUILD/convert_mnist_data.bin $DATA/t10k-images-idx3-ubyte \
      $DATA/t10k-labels-idx1-ubyte $EXAMPLE/mnist_test_${BACKEND} --backend=${BACKEND}

    echo "Done."

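Steps:
- Delete mnist_train_lmdb and mnist_test_lmdb under examples/mnist. BACKEND="lmdb" selects the storage format; besides lmdb, leveldb is also available.
- Convert the files downloaded in the previous step to lmdb format and save them under examples/mnist.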
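Conceptually, the conversion step pairs image i with label i and stores each pair under a sortable key. Below is a rough sketch of that idea only. A plain dict stands in for LMDB, the record fields are loosely modeled on Caffe's Datum message, and the zero-padded key format is an assumption rather than something confirmed from convert_mnist_data itself. The function name `to_records` is hypothetical.

```python
def to_records(images, labels, rows, cols):
    """Pair each image with its label under a zero-padded, sortable key."""
    if len(images) != len(labels):
        raise ValueError("image and label counts differ")
    records = {}
    for item_id, (pixels, label) in enumerate(zip(images, labels)):
        if len(pixels) != rows * cols:
            raise ValueError("unexpected image size")
        key = "%08d" % item_id  # assumed key format: "00000000", "00000001", ...
        records[key] = {"channels": 1, "height": rows, "width": cols,
                        "data": bytes(pixels), "label": label}
    return records

# Two fake 2x2 "images" stand in for the 28x28 MNIST digits.
images = [[0, 255, 255, 0], [255, 0, 0, 255]]
labels = [7, 3]
db = to_records(images, labels, rows=2, cols=2)
print(sorted(db))               # ['00000000', '00000001']
print(db["00000001"]["label"])  # 3
```

Zero-padding matters because a key-value store orders keys as byte strings; without it, key "10" would sort before key "2".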

The end result is two LMDB databases under examples/mnist: mnist_train_lmdb and mnist_test_lmdb.

./examples/mnist/train_lenet.sh

    #!/usr/bin/env sh
    set -e

    # Quote "$@" so extra arguments containing spaces are passed through intact.
    ./build/tools/caffe train --solver=examples/mnist/lenet_solver.prototxt "$@"

This runs the compiled caffe binary and trains according to the settings in lenet_solver.prototxt. To train on the CPU, change solver_mode on the last line of that file to CPU.

    # The train/test net protocol buffer definition
    net: "examples/mnist/lenet_train_test.prototxt"
    # test_iter specifies how many forward passes the test should carry out.
    # In the case of MNIST, we have test batch size 100 and 100 test iterations,
    # covering the full 10,000 testing images.
    test_iter: 100
    # Carry out testing every 500 training iterations.
    test_interval: 500
    # The base learning rate, momentum and the weight decay of the network.
    base_lr: 0.01
    momentum: 0.9
    weight_decay: 0.0005
    # The learning rate policy
    lr_policy: "inv"
    gamma: 0.0001
    power: 0.75
    # Display every 100 iterations
    display: 100
    # The maximum number of iterations
    max_iter: 10000
    # snapshot intermediate results
    snapshot: 5000
    snapshot_prefix: "examples/mnist/lenet"
    # solver mode: CPU or GPU
    solver_mode: CPU
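As a quick sanity check on these numbers, here is a small sketch of how the learning rate decays under the "inv" policy, which Caffe documents as base_lr * (1 + gamma * iter) ^ (-power). The function name `inv_lr` is made up for illustration. The training batch size of 64 comes from lenet_train_test.prototxt.

```python
def inv_lr(iteration, base_lr=0.01, gamma=0.0001, power=0.75):
    """Learning rate under the "inv" policy: base_lr * (1 + gamma * iter) ** (-power)."""
    return base_lr * (1.0 + gamma * iteration) ** (-power)

# Learning rate at the start, at the first test, at the first snapshot, and at the end.
for it in (0, 500, 5000, 10000):
    print(it, round(inv_lr(it), 6))

# Images seen in one test pass: test_iter * test batch size.
print(100 * 100)                     # 10000
# Approximate number of passes over the 60,000 training images.
print(round(10000 * 64 / 60000, 2))  # 10.67
```

The rate starts at base_lr and shrinks smoothly, ending at base_lr * 2 ** -0.75 after 10,000 iterations, so the whole run amounts to a little under eleven epochs with a gently decaying step size.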

lenet_solver.prototxt spells out every parameter, each with a detailed comment w(゚Д゚)w The most important line is probably this one:

    # The train/test net protocol buffer definition
    net: "examples/mnist/lenet_train_test.prototxt"

lenet_train_test.prototxt is pasted in full here without deeper analysis. For the proto syntax, see: Google Protocol Buffers data interchange protocol.

    name: "LeNet"
    layer {
      name: "mnist"
      type: "Data"
      top: "data"
      top: "label"
      include {
        phase: TRAIN
      }
      transform_param {
        scale: 0.00390625
      }
      data_param {
        source: "examples/mnist/mnist_train_lmdb"
        batch_size: 64
        backend: LMDB
      }
    }
    layer {
      name: "mnist"
      type: "Data"
      top: "data"
      top: "label"
      include {
        phase: TEST
      }
      transform_param {
        scale: 0.00390625
      }
      data_param {
        source: "examples/mnist/mnist_test_lmdb"
        batch_size: 100
        backend: LMDB
      }
    }
    layer {
      name: "conv1"
      type: "Convolution"
      bottom: "data"
      top: "conv1"
      param {
        lr_mult: 1
      }
      param {
        lr_mult: 2
      }
      convolution_param {
        num_output: 20
        kernel_size: 5
        stride: 1
        weight_filler {
          type: "xavier"
        }
        bias_filler {
          type: "constant"
        }
      }
    }
    layer {
      name: "pool1"
      type: "Pooling"
      bottom: "conv1"
      top: "pool1"
      pooling_param {
        pool: MAX
        kernel_size: 2
        stride: 2
      }
    }
    layer {
      name: "conv2"
      type: "Convolution"
      bottom: "pool1"
      top: "conv2"
      param {
        lr_mult: 1
      }
      param {
        lr_mult: 2
      }
      convolution_param {
        num_output: 50
        kernel_size: 5
        stride: 1
        weight_filler {
          type: "xavier"
        }
        bias_filler {
          type: "constant"
        }
      }
    }
    layer {
      name: "pool2"
      type: "Pooling"
      bottom: "conv2"
      top: "pool2"
      pooling_param {
        pool: MAX
        kernel_size: 2
        stride: 2
      }
    }
    layer {
      name: "ip1"
      type: "InnerProduct"
      bottom: "pool2"
      top: "ip1"
      param {
        lr_mult: 1
      }
      param {
        lr_mult: 2
      }
      inner_product_param {
        num_output: 500
        weight_filler {
          type: "xavier"
        }
        bias_filler {
          type: "constant"
        }
      }
    }
    layer {
      name: "relu1"
      type: "ReLU"
      bottom: "ip1"
      top: "ip1"
    }
    layer {
      name: "ip2"
      type: "InnerProduct"
      bottom: "ip1"
      top: "ip2"
      param {
        lr_mult: 1
      }
      param {
        lr_mult: 2
      }
      inner_product_param {
        num_output: 10
        weight_filler {
          type: "xavier"
        }
        bias_filler {
          type: "constant"
        }
      }
    }
    layer {
      name: "accuracy"
      type: "Accuracy"
      bottom: "ip2"
      bottom: "label"
      top: "accuracy"
      include {
        phase: TEST
      }
    }
    layer {
      name: "loss"
      type: "SoftmaxWithLoss"
      bottom: "ip2"
      bottom: "label"
      top: "loss"
    }
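Even without a deep analysis, it helps to see how the 28×28 input shrinks as it passes through the net. The sketch below walks through the conv and pool layers using the kernel sizes, strides and num_output values from the definition above, and counts the learnable parameters along the way. The helper names are made up for illustration, and padding is taken to be zero since the file does not set it.

```python
import math

def conv_out(size, kernel, stride=1):
    """Output width of an unpadded convolution."""
    return (size - kernel) // stride + 1

def pool_out(size, kernel, stride):
    """Output width of an unpadded pooling layer (Caffe rounds up here)."""
    return math.ceil((size - kernel) / stride) + 1

size, channels, params = 28, 1, 0
shapes = []
for name, kind, num_output, kernel, stride in [
    ("conv1", "conv", 20, 5, 1),
    ("pool1", "pool", None, 2, 2),
    ("conv2", "conv", 50, 5, 1),
    ("pool2", "pool", None, 2, 2),
]:
    if kind == "conv":
        # Weights (out * in * k * k) plus one bias per output channel.
        params += num_output * channels * kernel * kernel + num_output
        size, channels = conv_out(size, kernel, stride), num_output
    else:
        size = pool_out(size, kernel, stride)
    shapes.append((name, channels, size, size))

flat = channels * size * size  # number of inputs to ip1
params += flat * 500 + 500     # ip1 weights and biases
params += 500 * 10 + 10        # ip2 weights and biases

for shape in shapes:
    print(shape)
print(flat)    # 800
print(params)  # 431080
```

Most of the parameters sit in ip1 (800 × 500 weights), not in the convolution layers.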

Training results

The final result: accuracy reached an astonishing 98.98% w(゚Д゚)w

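Testing your own handwritten digit image

Once the model is trained, how do you actually put it to use? For MNIST, that means feeding a correctly formatted image to the network and checking whether the prediction is what you expect. I followed "Caffe notes: testing your own handwritten digit images" to test my own image with Caffe. I won't go through the procedure in detail here (the linked article is very thorough) and will only show the final command and its result: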

    ./build/examples/cpp_classification/classification.bin \
      examples/mnist/classificat_net.prototxt \
      examples/mnist/lenet_iter_10000.caffemodel \
      examples/mnist/mean.binaryproto \
      examples/mnist/label.txt \
      examples/mnist/0.png
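The label.txt argument is presumably a plain text file with one class name per line, where line i names output i of the network; that layout is an assumption here, since the post does not show the file. Under that assumption, a minimal sketch for generating it (the helper `write_labels` is hypothetical):

```python
import os
import tempfile

def write_labels(path, num_classes=10):
    """Write one label per line: line i holds the name of class i."""
    with open(path, "w") as f:
        for digit in range(num_classes):
            f.write("%d\n" % digit)

# Written to a temporary directory here; the command above expects
# the file at examples/mnist/label.txt.
label_path = os.path.join(tempfile.mkdtemp(), "label.txt")
write_labels(label_path)

with open(label_path) as f:
    labels = f.read().split()
print(labels)
```

The number of lines should match num_output of the last InnerProduct layer, which is 10 for this network.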

Note that classificat_net.prototxt is a copy of examples/mnist/lenet_train_test.prototxt, but with the following layers deleted:

    layer {
      name: "mnist"
      type: "Data"
      top: "data"
      top: "label"
      include {
        phase: TRAIN
      }
      transform_param {
        scale: 0.00390625
      }
      data_param {
        source: "examples/mnist/mnist_train_lmdb"
        batch_size: 64
        backend: LMDB
      }
    }
    layer {
      name: "mnist"
      type: "Data"
      top: "data"
      top: "label"
      include {
        phase: TEST
      }
      transform_param {
        scale: 0.00390625
      }
      data_param {
        source: "examples/mnist/mnist_test_lmdb"
        batch_size: 100
        backend: LMDB
      }
    }

as well as:

    layer {
      name: "accuracy"
      type: "Accuracy"
      bottom: "ip2"
      bottom: "label"
      top: "accuracy"
      include {
        phase: TEST
      }
    }
    layer {
      name: "loss"
      type: "SoftmaxWithLoss"
      bottom: "ip2"
      bottom: "label"
      top: "loss"
    }

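and with one layer appended at the end. With the Data layers removed, the net also needs an input declaration for the data blob (not shown here).

    layer {
      name: "prob"
      type: "Softmax"
      bottom: "ip2"
      top: "prob"
    }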

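The final prediction is shown below. The model assigns a probability of 1.0000 (to four decimal places) to the digit 0, and the prediction is correct: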

    ---------- Prediction for examples/mnist/0.png ----------
    1.0000 - "0"
    0.0000 - "1"
    0.0000 - "3"
    0.0000 - "4"
    0.0000 - "2"
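Summary

Plenty of good articles online already cover the model and the theory, so I won't go into them here. Installing Caffe on CentOS today was a real struggle... When compiling, I had disabled both leveldb and lmdb (in a Makefile-based build these are the USE_LEVELDB and USE_LMDB switches in Makefile.config), which meant the MNIST data format could not be recognized! The fix: change the configuration and recompile.

Of all the Caffe installation guides I went through, this one was the best: Caffe installation (CentOS 7)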
