@Team
2019-01-20T11:19:25.000000Z
字数 6231
阅读 1539
石文华
1、概述:
具有相同精度的CNN模型,较小的CNN架构至少有以下三个优点:
(1)、在分布式训练期间,较小的CNN需要较少的服务器通信。
(2)、较小的CNN可以减少从云端下载模型的带宽。
(3)、较小的CNN更适合部署在FPGA和内存有限的硬件上。SqueezeNet是一种小型化的CNN架构,具备上述提到的几点优势,SqueezeNet在ImageNet上实现了AlexNet级精度,参数却减少了50倍,是AlexNet的1/510。 此外,通过模型压缩技术,我们能够将SqueezeNet压缩到小于0.5MB(比AlexNet小510倍)。
2、设计此网络采用的策略:
(1)用1x1卷积核替换3x3卷积核,通道数相同的情况下,1x1的卷积核参数要比3x3的卷积核减少9倍。
(2)减少输入特征图的通道数,因为普通卷积的卷积核是:长x宽x通道,使用瓶颈层减少通道数的话参数就自然少了很多。
(3)延迟下采样,以使卷积层具有较大的激活的特征图,更大的激活图保留了更多的信息,可以提供更高的分类准确率。
3、网络的核心模块(Fire模块)
首先是squeeze 卷积层组成(只有1x1过滤器),接着是由1x1和3x3两种卷积都有的expand层,如下图所示:
上图中,s1x1 = 3,e1x1 = 4,e3x3 = 4,s1x1表示的是squeeze层中1x1卷积的数量,e1x1表示的是expand层中1x1卷积的数量,e3x3表示的是expand层中3x3卷积的数量。它们三个是超参数,当我们使用Fire模块时,我们将s1x1设置为小于(e1x1 + e3x3),使得squeeze层有助于限制3x3过滤器的输入通道数量,也就是expand层的输入特征图的通道数。
4、网络结构
我们现在描述SqueezeNet CNN架构。 如下图可知(中图和右图使用了ResNet网络中的shortcut作为提升策略),SqueezeNet
从一个独立的卷积层(conv1)开始,然后是8个Fire模块(fire2-9),
最后一个卷积层(conv10)。 从网络的开始到结束,逐渐增加每个Fire模块的过滤器数量
。根据策略3,需要延迟下采样,所以SqueezeNet分别在图层conv1,fire4,fire8和conv10以步幅2执行最大池化,尽可能使卷积层具有较大的特征图;
有关SqueezeNet的详细信息和设计选择:
(1)1x1和3x3卷积核的输出激活具有相同的高度和宽度,因此需要对3x3卷积前的特征图进行填充,使输出特征图跟输入特征图大小相同。
(2)Fire模块里面的squeeze层和expand层采用Relu函数进行激活
(3)在Fire9模块之后应用Dropout,比例为50%,
(4)没有全连接层,而是采用全局平均池化。
(5)在训练SqueezeNet时,学习率从0.04开始,并且在整个训练中线性降低学习率。
5、超参数
在SqueezeNet中,每一个Fire module有3个维度的超参数,即s1x1 、 e1x1 和 e3x3。SqueezeNet一共有8个Fire modules,即一共24个超参数。下面两个是需要注意的比例关系:
(1)、SR:压缩比,即the squeeze ratio ,为squeeze层中filter个数除以Fire module中filter总个数得到的一个比例。
(2)、pct3x3:在expand层有1x1和3x3两种卷积,这里定义的参数是3x3卷积个占卷积总个数的比例。
分别测试SR与模型准确率以及模型大小的关系、pct3x3与模型准确率以及模型大小的关系。如下图可知,左图给出了压缩比(SR)的影响。压缩比小于0.25时,正确率开始显著下降。右图给出了3∗3卷积比例的影响,在比例小于25%时,正确率开始显著下降,此时模型大小约为原先的44%。超过50%后,模型大小显著增加,但是正确率不再上升。
6、代码如下:
import keras
from keras.models import Model
from keras.layers import Input,Dense,Activation,Dropout,Flatten,GlobalAveragePooling2D
from keras.layers import Conv2D,MaxPool2D,Concatenate
def SqueezeNet(img_w,img_h,n_channels):
#输入
input_shape=(img_w,img_h,n_channels)
#输入层
img_input=Input(shape=input_shape,name='img_input')
#第一个卷积层
conv1=Conv2D(filters=96,kernel_size=(7,7),strides=(2,2),padding='same',activation='relu',name='conv1')(img_input)
maxpool1=MaxPool2D(pool_size=(3,3),strides=(2,2),name='maxpool1')(conv1)
#第一组fire
fire1_squee=Conv2D(filters=16,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire1_squee')(maxpool1)
fire1_expand1=Conv2D(filters=64,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire1_expand1')(fire1_squee)
fire1_expand3=Conv2D(filters=64,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire1_expand3')(fire1_squee)
fire1=Concatenate(axis=-1)([fire1_expand1,fire1_expand3]) #合并
#第二组fire
fire2_squee=Conv2D(filters=16,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire2_squee')(fire1)
fire2_expand1=Conv2D(filters=64,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire2_expand1')(fire2_squee)
fire2_expand3=Conv2D(filters=64,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire2_expand3')(fire2_squee)
fire2=Concatenate(axis=-1)([fire2_expand1,fire2_expand3]) #合并
#第三组fire
fire3_squee=Conv2D(filters=32,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire3_squee')(fire2)
fire3_expand1=Conv2D(filters=128,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire3_expand1')(fire3_squee)
fire3_expand3=Conv2D(filters=128,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire3_expand3')(fire3_squee)
fire3=Concatenate(axis=-1)([fire3_expand1,fire3_expand3]) #合并
#下采样
maxpool2=MaxPool2D(pool_size=(3,3),strides=(2,2),name='maxpool2')(fire3)
#第四组fire
fire4_squee=Conv2D(filters=32,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire4_squee')(maxpool2)
fire4_expand1=Conv2D(filters=128,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire4_expand1')(fire4_squee)
fire4_expand3=Conv2D(filters=128,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire4_expand3')(fire4_squee)
fire4=Concatenate(axis=-1)([fire4_expand1,fire4_expand3]) #合并
#第五组fire
fire5_squee=Conv2D(filters=48,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire5_squee')(fire4)
fire5_expand1=Conv2D(filters=192,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire5_expand1')(fire5_squee)
fire5_expand3=Conv2D(filters=192,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire5_expand3')(fire5_squee)
fire5=Concatenate(axis=-1)([fire5_expand1,fire5_expand3]) #合并
#第六组fire
fire6_squee=Conv2D(filters=48,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire6_squee')(fire5)
fire6_expand1=Conv2D(filters=192,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire6_expand1')(fire6_squee)
fire6_expand3=Conv2D(filters=192,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire6_expand3')(fire6_squee)
fire6=Concatenate(axis=-1)([fire6_expand1,fire6_expand3]) #合并
#第七组fire
fire7_squee=Conv2D(filters=64,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire7_squee')(fire6)
fire7_expand1=Conv2D(filters=256,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire7_expand1')(fire7_squee)
fire7_expand3=Conv2D(filters=256,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire7_expand3')(fire7_squee)
fire7=Concatenate(axis=-1)([fire7_expand1,fire7_expand3]) #合并
#下采样
maxpool3=MaxPool2D(pool_size=(3,3),strides=(2,2),name='maxpool3')(fire7)
#第八组fire
fire8_squee=Conv2D(filters=64,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire8_squee')(maxpool3)
fire8_expand1=Conv2D(filters=256,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='fire8_expand1')(fire8_squee)
fire8_expand3=Conv2D(filters=256,kernel_size=(3,3),strides=(1,1),padding='same',activation='relu',name='fire8_expand3')(fire8_squee)
fire8=Concatenate(axis=-1)([fire8_expand1,fire8_expand3]) #合并
conv2=Conv2D(filters=2,kernel_size=(1,1),strides=(1,1),padding='same',activation='relu',name='conv2')(fire8)
Gap=GlobalAveragePooling2D()(conv2)
model=Model(img_input,Gap)
return model
model=SqueezeNet(224,224,3)
model.summary()
输出结果如下:
7、参考
https://blog.csdn.net/csdnldp/article/details/78648543
https://arxiv.org/pdf/1602.07360.pdf