前言
马马虎虎学了一遍,面试时候让我tensorflow手写一个简单的前向传播反向训练过程,发现自己根本写不出来,好多函数用的不多,都是搬代码,天道好轮回,苍天饶过谁......
主要内容来源于《Tensorflow实战Google深度学习框架》,更注重tensorflow的基础。
以一个简单的为例,实现一个简单的神经网络单层结构。
实践过程如下
import tensorflow as tf
from numpy.random import RandomState
import warnings
warnings.filterwarnings('ignore')
定义 喂入的数据大小。我们可以一个一个的进行数据更新,也可以 一次性输入全部数据(内存放不下就会导致OOM)。
batch_size = 8
tf.placeholder()定义输入结点输出结点,使用这个定义的输入变量是需要输入的数据,也就是feed_dict阶段需要喂进去的
主要参数为 数据类型, shape=[ ] 输入张量的尺寸, name 命名。
当shape=[None,]为None时,即表示输入的数字待定,由具体喂入的数据所决定。结合训练中feed_dict部分一起看
x = tf.placeholder(tf.float32, shape=[None,2], name='x_input')
y_ = tf.placeholder(tf.float32, shape=[None,1], name='y_output')
定义参数变量,这个变量是需要训练的,随机初始化,然后在优化过程中不断被更改
tf.Variable()用来保存和更新神经网络中的参数。
tf.random_normal() 正太分布(主要参数为 平均值mean=,标准差stddev=,取值类型dtype=)
tf.truncated_normal() 正太分布,但是随机出来的值偏离平均值超过两个标准差,这个数则会被重新随机
tf.random_uniform() 均匀分布
tf.random_gamma() gamma分布
w1 = tf.Variable(tf.random_normal(shape=[2,1], stddev=1, seed=1))
tf.matmul()是matrix-multiply的简写,计算矩阵乘法。如果两个矩阵按位乘则用 * 即可。
y = tf.matmul(x,w1)
损失函数:神经网络的关键就是去优化损失函数。
这里相当于自定义损失,当loss_less = loss_more 就是正常无差别的
tf.greater(a,b)比较两个数大小
tf.where()有三个参数,第一个是条件,如果条件成立等于第二个参数 否则等于第三个参数。
loss_less = 10
loss_more = 1
loss = tf.reduce_sum(tf.where(tf.greater(y,y_),(y-y_)*loss_more,(y_-y)*loss_less))
优化方法:类似梯度下降等一些列,只不过这个 比梯度下降更好用。之后专门写一篇优化算法的。
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
随机产生数据。
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
设置回归的正确值为:Y=2*X1+X2+C
Y = [[2*x1+x2+rdm.rand()/10.0-0.05] for (x1,x2) in X]
训练神经网络:
1.初始化变量
init_op = tf.global_variables_initializer()
sess.run(init_op)
2.设置喂入的数据
3.开始优化train 并 喂入数据。
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
STEPS = 10000
for i in range(STEPS):
start = (i* batch_size) % dataset_size
end = min(start+batch_size,dataset_size)
sess.run(train_step,
feed_dict={x:X[start:end],y_:Y[start:end]})
if i % 1000 == 0:
print("当前轮数 {}".format(i))
print("参数更新为{}".format(sess.run(w1)))
当前轮数 0
参数更新为[[-0.81031823]
[ 1.4855988 ]]
当前轮数 1000
参数更新为[[0.11504051]
[2.3188334 ]]
当前轮数 2000
参数更新为[[0.77901006]
[2.6848948 ]]
当前轮数 3000
参数更新为[[1.2026749]
[2.6662533]]
当前轮数 4000
参数更新为[[1.4926519]
[2.360118 ]]
当前轮数 5000
参数更新为[[1.7392459]
[1.8202068]]
当前轮数 6000
参数更新为[[1.9628013]
[1.1227087]]
当前轮数 7000
参数更新为[[2.0193663]
[1.0427792]]
当前轮数 8000
参数更新为[[2.0194335]
[1.0438062]]
当前轮数 9000
参数更新为[[2.0200043]
[1.043436 ]]
网友评论