[关闭]
@Team 2019-01-20T11:19:25.000000Z 字数 6231 阅读 1539

《Computer vision》笔记-SqueezeNet(6)

石文华


1、概述:
具有相同精度的CNN模型,较小的CNN架构至少有以下三个优点:
(1)、在分布式训练期间,较小的CNN需要较少的服务器通信。
(2)、较小的CNN可以减少从云端下载模型的带宽。
(3)、较小的CNN更适合部署在FPGA和内存有限的硬件上。SqueezeNet是一种小型化的CNN架构,具备上述提到的几点优势,SqueezeNet在ImageNet上实现了AlexNet级精度,参数却减少了50倍,是AlexNet的1/510。 此外,通过模型压缩技术,我们能够将SqueezeNet压缩到小于0.5MB(比AlexNet小510倍)。
2、设计此网络采用的策略:
(1)用1x1卷积核替换3x3卷积核,通道数相同的情况下,1x1的卷积核参数要比3x3的卷积核减少9倍。
(2)减少输入特征图的通道数,因为普通卷积的卷积核是:长x宽x通道,使用瓶颈层减少通道数的话参数就自然少了很多。
(3)延迟下采样,以使卷积层具有较大的激活的特征图,更大的激活图保留了更多的信息,可以提供更高的分类准确率。
3、网络的核心模块(Fire模块)
首先是squeeze 卷积层组成(只有1x1过滤器),接着是由1x1和3x3两种卷积都有的expand层,如下图所示:

image.png-54.2kB
上图中,s1x1 = 3,e1x1 = 4,e3x3 = 4,s1x1表示的是squeeze层中1x1卷积的数量,e1x1表示的是expand层中1x1卷积的数量,e3x3表示的是expand层中3x3卷积的数量。它们三个是超参数,当我们使用Fire模块时,我们将s1x1设置为小于(e1x1 + e3x3),使得squeeze层有助于限制3x3过滤器的输入通道数量,也就是expand层的输入特征图的通道数。
4、网络结构
我们现在描述SqueezeNet CNN架构。 如下图可知(中图和右图使用了ResNet网络中的shortcut作为提升策略),SqueezeNet
从一个独立的卷积层(conv1)开始,然后是8个Fire模块(fire2-9),
最后一个卷积层(conv10)。 从网络的开始到结束,逐渐增加每个Fire模块的过滤器数量
。根据策略3,需要延迟下采样,所以SqueezeNet分别在图层conv1,fire4,fire8和conv10以步幅2执行最大池化,尽可能使卷积层具有较大的特征图;

image.png-136.9kB
有关SqueezeNet的详细信息和设计选择:
(1)1x1和3x3卷积核的输出激活具有相同的高度和宽度,因此需要对3x3卷积前的特征图进行填充,使输出特征图跟输入特征图大小相同。
(2)Fire模块里面的squeeze层和expand层采用Relu函数进行激活
(3)在Fire9模块之后应用Dropout,比例为50%,
(4)没有全连接层,而是采用全局平均池化。
(5)在训练SqueezeNet时,学习率从0.04开始,并且在整个训练中线性降低学习率。
5、超参数
在SqueezeNet中,每一个Fire module有3个维度的超参数,即s1x1 、 e1x1 和 e3x3。SqueezeNet一共有8个Fire modules,即一共24个超参数。下面两个是需要注意的比例关系:
(1)、SR:压缩比,即the squeeze ratio ,为squeeze层中filter个数除以Fire module中filter总个数得到的一个比例。
(2)、pct3x3:在expand层有1x1和3x3两种卷积,这里定义的参数是3x3卷积个占卷积总个数的比例。
分别测试SR与模型准确率以及模型大小的关系、pct3x3与模型准确率以及模型大小的关系。如下图可知,左图给出了压缩比(SR)的影响。压缩比小于0.25时,正确率开始显著下降。右图给出了3∗3卷积比例的影响,在比例小于25%时,正确率开始显著下降,此时模型大小约为原先的44%。超过50%后,模型大小显著增加,但是正确率不再上升。

image.png-91.2kB

6、代码如下:

  1. import keras
  2. from keras.models import Model
  3. from keras.layers import Input,Dense,Activation,Dropout,Flatten,GlobalAveragePooling2D
  4. from keras.layers import Conv2D,MaxPool2D,Concatenate
  5. def SqueezeNet(img_w,img_h,n_channels):
  6. #输入
  7. input_shape=(img_w,img_h,n_channels)
  8. #输入层
  9. img_input=Input(shape=input_shape,name='img_input')
  10. #第一个卷积层
  11. conv1=Conv2D(filters=96,kernel_size=(7,7),strides=(2,2),padding='same',activation='relu',name='conv1')(img_input)
  12. maxpool1=MaxPool2D(pool_size=(3,3),strides=(2,2),name='maxpool1')(conv1)
  13. #第一组fire
  14. fire1_squee=Conv2D(filters=16,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire1_squee')(maxpool1)
  15. fire1_expand1=Conv2D(filters=64,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire1_expand1')(fire1_squee)
  16. fire1_expand3=Conv2D(filters=64,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire1_expand3')(fire1_squee)
  17. fire1=Concatenate(axis=-1)([fire1_expand1,fire1_expand3]) #合并
  18. #第二组fire
  19. fire2_squee=Conv2D(filters=16,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire2_squee')(fire1)
  20. fire2_expand1=Conv2D(filters=64,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire2_expand1')(fire2_squee)
  21. fire2_expand3=Conv2D(filters=64,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire2_expand3')(fire2_squee)
  22. fire2=Concatenate(axis=-1)([fire2_expand1,fire2_expand3]) #合并
  23. #第三组fire
  24. fire3_squee=Conv2D(filters=32,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire3_squee')(fire2)
  25. fire3_expand1=Conv2D(filters=128,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire3_expand1')(fire3_squee)
  26. fire3_expand3=Conv2D(filters=128,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire3_expand3')(fire3_squee)
  27. fire3=Concatenate(axis=-1)([fire3_expand1,fire3_expand3]) #合并
  28. #下采样
  29. maxpool2=MaxPool2D(pool_size=(3,3),strides=(2,2),name='maxpool2')(fire3)
  30. #第四组fire
  31. fire4_squee=Conv2D(filters=32,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire4_squee')(maxpool2)
  32. fire4_expand1=Conv2D(filters=128,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire4_expand1')(fire4_squee)
  33. fire4_expand3=Conv2D(filters=128,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire4_expand3')(fire4_squee)
  34. fire4=Concatenate(axis=-1)([fire4_expand1,fire4_expand3]) #合并
  35. #第五组fire
  36. fire5_squee=Conv2D(filters=48,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire5_squee')(fire4)
  37. fire5_expand1=Conv2D(filters=192,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire5_expand1')(fire5_squee)
  38. fire5_expand3=Conv2D(filters=192,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire5_expand3')(fire5_squee)
  39. fire5=Concatenate(axis=-1)([fire5_expand1,fire5_expand3]) #合并
  40. #第六组fire
  41. fire6_squee=Conv2D(filters=48,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire6_squee')(fire5)
  42. fire6_expand1=Conv2D(filters=192,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire6_expand1')(fire6_squee)
  43. fire6_expand3=Conv2D(filters=192,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire6_expand3')(fire6_squee)
  44. fire6=Concatenate(axis=-1)([fire6_expand1,fire6_expand3]) #合并
  45. #第七组fire
  46. fire7_squee=Conv2D(filters=64,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire7_squee')(fire6)
  47. fire7_expand1=Conv2D(filters=256,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire7_expand1')(fire7_squee)
  48. fire7_expand3=Conv2D(filters=256,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire7_expand3')(fire7_squee)
  49. fire7=Concatenate(axis=-1)([fire7_expand1,fire7_expand3]) #合并
  50. #下采样
  51. maxpool3=MaxPool2D(pool_size=(3,3),strides=(2,2),name='maxpool3')(fire7)
  52. #第八组fire
  53. fire8_squee=Conv2D(filters=64,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire8_squee')(maxpool3)
  54. fire8_expand1=Conv2D(filters=256,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire8_expand1')(fire8_squee)
  55. fire8_expand3=Conv2D(filters=256,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire8_expand3')(fire8_squee)
  56. fire8=Concatenate(axis=-1)([fire8_expand1,fire8_expand3]) #合并
  57. conv2=Conv2D(filters=2,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='conv2')(fire8)
  58. Gap=GlobalAveragePooling2D()(conv2)
  59. model=Model(img_input,Gap)
  60. return model
  61. model=SqueezeNet(224,224,3)
  62. model.summary()

输出结果如下:
model.png-218.3kB

7、参考
https://blog.csdn.net/csdnldp/article/details/78648543
https://arxiv.org/pdf/1602.07360.pdf

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注