美文网首页深度学习研究所
tf.nn.conv2d中传入的第一二个参数的数据格式问题

tf.nn.conv2d中传入的第一二个参数的数据格式问题

作者: 西方失败9527 | 来源:发表于2017-09-17 15:48 被阅读0次

看到知乎上这样一个问题:

下面这个图所示,输入数据是一个2个通道3*3的数据,过滤器是一个具有两个通道的2*2的数据,按照一般卷积过程,即如果所示结果是一个通道的2*2的数据。

但是在tensorflow中,我们如下实现:

k = tf.constant([ 1,2 ,3,4,

                            5,6,7,8], dtype=tf.float32, name='k')

i = tf.constant([

                       1, 3, 5,

                       1, 3, 5,

                       1, 3, 5,

                        2, 4, 6,

                        2, 4, 6,

                        2, 4, 6

                        ], dtype=tf.float32, name='i')

kernel = tf.reshape(k, [2, 2, 2, 1], name='kernel')

image  = tf.reshape(i, [1, 3, 3, 2], name='image')

#res = tf.nn.conv2d(image, kernel, [1, 1, 1, 1], "VALID")

res = tf.squeeze(tf.nn.conv2d(image, kernel, [1, 1, 1, 1], "VALID"))# VALID means no padding

with tf.Session() as sess:

            print(sess.run(res))

结果不对

原因原来是data_format 参数的问题,图像数据格式定义了一批图片数据的存储顺序。在调用 TensorFlow API 时会经常看到 data_format 参数:

data_format 默认值为 "NHWC",也可以手动设置为 "NCHW"。这个参数规定了 input Tensor 和 output Tensor 的排列方式。

data_format 设置为 "NHWC" 时,排列顺序为 [batch, height, width, channels];

                      设置为 "NCHW" 时,排列顺序为 [batch, channels, height, width]。

其中 N 表示这批图像有几张,H 表示图像在竖直方向有多少像素,W 表示水平方向像素数,C 表示通道数(例如黑白图像的通道数 C = 1,而 RGB 彩色图像的通道数 C = 3)。为了便于演示,我们后面作图均使用 RGB 三通道图像。两种格式的区别如下图所示:

NCHW 中,C 排列在外层,每个通道内像素紧挨在一起,即 'RRRRRRGGGGGGBBBBBB' 这种形式。

NHWC 格式,C 排列在最内层,多个通道对应空间位置的像素紧挨在一起,即 'RGBRGBRGBRGBRGBRGB' 这种形式。

于是我们的程序中将数据顺序修改即可:

k = tf.constant([

1, 5,

2, 6,

3, 7,

4, 8

], dtype=tf.float32, name='k')

i = tf.constant([

1, 2, 3,

4, 5, 6,

1, 2, 3,

4, 5, 6,

1, 2, 3,

4, 5, 6

], dtype=tf.float32, name='i')

kernel = tf.reshape(k, [2, 2, 2, 1], name='kernel')

image  = tf.reshape(i, [1, 3, 3, 2], name='image')

#res = tf.nn.conv2d(image, kernel, [1, 1, 1, 1], "VALID")

res = tf.squeeze(tf.nn.conv2d(image, kernel, [1, 1, 1, 1], "VALID"))# VALID means no padding

with tf.Session() as sess:

            print(sess.run(image))

            print("------------------")

            print(sess.run(kernel))

            print("------------------")

           print(sess.run(res))

最终能如愿以偿得到如图右边的结果。不过feature map的172应该改为174,手算也该如此

主要参考:http://mp.weixin.qq.com/s/I4Q1Bv7yecqYXUra49o7tw

相关文章

  • tf.nn.conv2d中传入的第一二个参数的数据格式问题

    看到知乎上这样一个问题: 下面这个图所示,输入数据是一个2个通道3*3的数据,过滤器是一个具有两个通道的2*2的数...

  • 通过 iview 的checkbox 控制菜单树

    数据格式 后台返回数据格式为 样式 代码 iview 定义的函数可以传入自定义参数调用 数组中存放数字正常使用, ...

  • iOS: 延缓执行的几种方法

    一. performSelector /** 第一个参数:需要延迟执行的方法 第二个参数:要传入的参数(id类型...

  • 深入理解ES6三

    参数 ES6简化了为形式参数提供默认值的过程 上面的函数只有当不为第二个参数传入值或者主动为第二个参数传入unde...

  • 3.函数参数的默认值

    函数参数的默认值 ES6可以写成 如果只想传入第二个参数,第一个参数应该是 undefined

  • JS简单的日期格式化封装

    日期格式化 第一个参数传入标准时间第二个参数传入格式例如: 'yyyy-MM-DD' 'yyyy-MM-DD H...

  • AngularJS自定义指令

    使用.directive()方法来注册一个新指令传入两个参数,第一个参数传入一个字符串,作为指令的名字;第二个参数...

  • Shell脚本中的参数

    脚本中给的各种参数 $#:传入脚本的参数个数; $0: 脚本自身的名称; $1: 传入脚本的第一个参数; $2...

  • TensorFlow(3)CNN中的函数

    tf.nn.conv2d()函数 参数介绍: tf.nn.conv2d(input, filter, stride...

  • leetcode的每天一题更新(two sum)

    这个题目的问题是传入两个参数,第一个参数是一个数组,第二个是一个数字,如果数组中有两个数字相加等于第二个参数则将这...

网友评论

    本文标题: tf.nn.conv2d中传入的第一二个参数的数据格式问题

    本文链接:https://www.haomeiwen.com/subject/axiesxtx.html