美文网首页
Xception 算法

Xception 算法

作者: LuDon | 来源:发表于2018-12-03 14:58 被阅读35次

引言

Xception是google在inception之后提出的对inceptionV3的另一种改进,主要采用depthwise separable convolution来替换原来的inception v3中的卷积操作。

思考

要解决什么问题?怎么解决的?

  • 探寻Inception的基本思路
  • 从Inception发展历程的角度,理解其基本思想,并引入与Inception类似的Depthwise Separable Convolution结构。
  • 将Inception V3结构中的Inception改用Depthwise Separable Convolution。

效果怎么样?

  • 在与Inception V3参数数量相差无几的情况下,在ImageNet上性能有略微上升,JFT上有明显提高。

还存在什么问题?

  • Depthwise Separable Convolution不一定就是最优结构,还有尚未探索、验证的相似结构。

相关知识

Xception是inception系列的成员之一。
inception与普通的卷积操作相比,具有更强的表达能力。

inception系列的复习

inception结构如图1所示,用三种conv计算合并之后代替原来的conv。

图1.inception结构
选用卷积核1 图2.降维后的inception结构

两个3 \times 3的卷积核可以替代5 \times 5的卷积核,因此结构变为图3。

图3
以上模块主要在inceptionV3中,inceptionV3的基本结构为:
input
conv2d(32, 3, 3, s=2) #conv2d_1a
conv2d(32, 3, 3) #conv2d_2a
conv2d(64, 3, 3, 'SAME') #conv2d_2b
max_pool2d(3, 3, ,s=2) #maxpool_3a
conv2d(1, 1, 80) #conv2d_3b
conv2d(3, 3, 192) #conv2d_4a
max_pool2d(3, 3, s=2) #maxpool_5a

conv2d(1, 1, 64)     conv2d(1, 1, 48)      conv2d(1, 1, 64)    avgpool(3, 3)
                     conv2d(5, 5, 64)      conv2d(3, 3, 96)     conv2d(1, 1, 32)
                                           conv2d(3, 3, 96)

concat
*9
conv2d(1, 1, num_class)

在以上模块中,对于一个conv层来说,需要学习的是一个3D的卷积核,其中包括两个空间维度和一个通道维度,即w,h,c。这个卷积核与输入在3个维度上进行卷积操作,得到最终的结果,伪代码如下:

// 对于第i个filter
// 计算输入中心点(x, y)对应的卷积结果
sum = 0
for c in 1:C
  for h in 1:K
    for w in 1:K
      sum += input[c, y-K/2+h, x-K/2+w] * filter_i[c, h, w]
out[i, y, x] = sum

可以看出在3D的卷积中,通道这个维度与空间的两个维度是一样的。

先用一个统一的1 \times 1的卷积核卷积,然后连接三个3 \times 3的卷积核,如图4所示。这3个卷积操作只将前面的1 \times 1卷积结果中的一部分作为自己的输入。图中是将1/3通道作为每个卷积核的输入。

图4

再将3 \times 3卷积核的个数延伸到与1 \times 1卷积核输出通道的个数一样,即每个3 \times 3的卷积核和1个输入通道做卷积,如图5所示。

图5

Xception

Xception主要使用depthwise separable convolution,即将传统的卷积操作分成两步:

  • depthwise convolution
    M个3 \times 3的卷积核一对一卷积输入的M个特征图,不求和,生成M个结果。

    depthwist
  • pointwise convolution
    用N个1 \times 1的卷积核正常卷积前面生成的M个结果。

图6

depthwise separable convolution和以上结构的不同之处:

  • 操作的顺序不同。depthwise separable conv的实现是先使用channelwise的filter只在spatial dimension上做卷积,再使用1×1的卷积核做跨channel的融合。而Inception中先使用1×1的卷积核。
  • 非线性变换的缺席。在Inception中,每个conv操作后面都有ReLU的非线性变换,而depthwise separable conv没有。

Xception结构是将ResNet的相关卷积变成了depthwise separable conv,如下图所示。其中SeparableConv是depthwise separable conv模块。另外,原来的concat变成了residual connection。


Xception结构图

参考文献

[1] Xception: Deep Learning with Depthwise Separable Convolutions

代码分析

### Xception.py
from keras.preprocessing import image
from keras.models import Model
from keras import layers
from keras.layers import Dense
from keras.layers import Input
from keras.layers import BatchNormalization
from keras.layers import Activation
from keras.layers import Conv2D
from keras.layers import SeparableConv2D
from keras.layers import MaxPooling2D
from keras.layers import GlobalAveragePooling2D
import tensorflow as tf

input_tensor = tf.ones([1, 224, 224, 3])
input_shape = [224, 224, 3]
img_input = Input(tensor=input_tensor, shape=input_shape)

x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False, name='block1_conv1')(img_input) #(1, 112, 112, 32)
x = BatchNormalization(name='block1_conv1_bn')(x)
x = Activation('relu', name='block1_conv1_act')(x)
x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)  #(1, 109, 109, 64)
x = BatchNormalization(name='block1_conv2_bn')(x)
x = Activation('relu', name='block1_conv2_act')(x)
 
residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual) #(1, 55, 55, 128)
 
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
x = BatchNormalization(name='block2_sepconv1_bn')(x)
x = Activation('relu', name='block2_sepconv2_act')(x)
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
x = BatchNormalization(name='block2_sepconv2_bn')(x)
 
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block2_pool')(x)
x = layers.add([x, residual])
residual = Conv2D(256, (1, 1), strides=(2, 2),
                      padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
 
x = Activation('relu', name='block3_sepconv1_act')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
x = BatchNormalization(name='block3_sepconv1_bn')(x)
x = Activation('relu', name='block3_sepconv2_act')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
x = BatchNormalization(name='block3_sepconv2_bn')(x)
 
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block3_pool')(x)
x = layers.add([x, residual])
 
residual = Conv2D(728, (1, 1), strides=(2, 2),
                      padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block4_sepconv1_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
x = BatchNormalization(name='block4_sepconv1_bn')(x)
x = Activation('relu', name='block4_sepconv2_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
x = BatchNormalization(name='block4_sepconv2_bn')(x)
 
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block4_pool')(x)
x = layers.add([x, residual])
 
for i in range(8):
    residual = x
    prefix = 'block' + str(i + 5)
 
    x = Activation('relu', name=prefix + '_sepconv1_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv1')(x)
    x = BatchNormalization(name=prefix + '_sepconv1_bn')(x)
    x = Activation('relu', name=prefix + '_sepconv2_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv2')(x)
    x = BatchNormalization(name=prefix + '_sepconv2_bn')(x)
    x = Activation('relu', name=prefix + '_sepconv3_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv3')(x)
    x = BatchNormalization(name=prefix + '_sepconv3_bn')(x)
 
    x = layers.add([x, residual])
 
residual = Conv2D(1024, (1, 1), strides=(2, 2),
                      padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
 
x = Activation('relu', name='block13_sepconv1_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
x = BatchNormalization(name='block13_sepconv1_bn')(x)
x = Activation('relu', name='block13_sepconv2_act')(x)
x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
x = BatchNormalization(name='block13_sepconv2_bn')(x)
 
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block13_pool')(x)
x = layers.add([x, residual])
 
x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
x = BatchNormalization(name='block14_sepconv1_bn')(x)
x = Activation('relu', name='block14_sepconv1_act')(x)
 
x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
x = BatchNormalization(name='block14_sepconv2_bn')(x)
x = Activation('relu', name='block14_sepconv2_act')(x)
 
if include_top:
    x = GlobalAveragePooling2D(name='avg_pool')(x)
    x = Dense(classes, activation='softmax', name='predictions')(x)
else:
    if pooling == 'avg':
        x = GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = GlobalMaxPooling2D()(x)
 
if input_tensor is not None:
    inputs = get_source_inputs(input_tensor)
 else:
    inputs = img_input
 
model = Model(inputs, x, name='xception')
 
if weights == 'imagenet':
    if include_top:
        weights_path = get_file('xception_weights_tf_dim_ordering_tf_kernels.h5',
                                    TF_WEIGHTS_PATH,
                                    cache_subdir='models')
    else:
        weights_path = get_file('xception_weights_tf_dim_ordering_tf_kernels_notop.h5',
                                    TF_WEIGHTS_PATH_NO_TOP,
                                    cache_subdir='models')
    model.load_weights(weights_path)
 if old_data_format:
    K.set_image_data_format(old_data_format)
return model

[1] 代码参考

相关文章

  • Xception 算法

    引言 Xception是google在inception之后提出的对inceptionV3的另一种改进,主要采用d...

  • 小型CNN总结:ShuffleNet、MobileNet v1,

    推荐的文章包括: ShuffleNet,mobilenet v1,v2,Xception Xception、Mob...

  • 经典卷积网络之Xception

    Xception模型 一、模型框架 Xception是谷歌公司继Inception后,提出的InceptionV3...

  • Xception

    Approach Two minor differences between and “extreme” vers...

  • Xception

    Xception是在Inception的基础上提出来的一种新网络,其对Inception中的Inception m...

  • Xception

    论文原文Xception: Deep Learning with Depthwise Separable Conv...

  • MobileNet系列

    MobileNet V1 思想 思想主要来源于Xception,Xception也是谷歌的作品,主要就是引入了se...

  • maven环境的配置问题,Unsupported major.m

    xception in thread "main" java.lang.UnsupportedClassVersi...

  • Inception、Xception

    如果 ResNet 是为了更深,那么 Inception 家族就是为了更宽。Inception 的作者对训练更大型...

  • 论文阅读 --- Xception

    关于图像分类的论文《Xception: Deep Learning with Depthwise Separabl...

网友评论

      本文标题:Xception 算法

      本文链接:https://www.haomeiwen.com/subject/glgjcqtx.html