美文网首页机器学习与数据挖掘深度学习
2020自己动手写循环神经网(1)

2020自己动手写循环神经网(1)

作者: zidea | 来源:发表于2020-02-12 19:45 被阅读0次
machine_learning.jpg

循环神经网络

循环神经网

循环神经网络处理变长的序列,在这点上不同于普通神经网是。在时间序列定义一种递归关系来实现循环的。
S_k = f(S_{k-1} \cdot W_{rec} + X_k \cdot W_x)

  • S_k 表示在 k 时刻的状态(记忆)
  • X_k 表示在 k 时刻神经元的输入
  • W_{rec}W_x 是类似普通前馈神经网络的权重参数

循环神经网的隐含状态(记忆单元)会在整个时间序列上不断地被更新。在神经网隐藏层输出会作为被状态被记忆,在下一次作为输入一部分参与计算。我们可以认为循环神经网络是带有反馈回路(feedback loop),这里反馈回路意思没出计算输出作为状态存储 k 时刻,这个存储 k 时刻状态参与到下一次 k+1 时刻的计算。通过状态可以将信息在时间上序列进行传递。


rnn_simple_01.jpg

所以最终输出Y_k包含之前S_1,S_2,\dots,S_{k-1} 之前的状态。循环网络与前馈网络没有太大的区别,但是在训练集,我们训练是在早期是一个比较困难事情,因为损失函数很不稳定。这个随后会介绍一下。

自己手写一个简单的循环神经网,该循环神经网经过训练,可以计算在二进制(0 或 1)输入序列上计算出现 1 的个数,并在序列结束时输出总计数。如下图循环卷积有一个状态,在每一个时刻输入一个序列中的元素,计算输出后更新状态,最后输出 y

rnn_simple_02.png

下图左边为循环神经元的示例图,而右侧为其在时间序列上展开的效果。我们可将这个神经元进行展开,类似有共享参数w_recw_x的 n+1 层。

%matplotlib notebook
import sys
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import LogNorm
import seaborn as sns

sns.set_style('darkgrid')
np.random.seed(seed=1)

定义数据集

20 样本,每一个样本序列包含 10 二进制作为二进制。

# 创建数据集
nb_of_samples = 20
sequence_len = 10
# 创建序列
X = np.zeros((nb_of_samples, sequence_len))
# print(X)
for row_idx in range(nb_of_samples):
    X[row_idx,:] = np.around(np.random.rand(sequence_len)).astype(int)
t = np.sum(X, axis=1)
t
array([6., 3., 7., 4., 3., 4., 6., 7., 8., 4., 3., 4., 7., 7., 5., 5., 5.,
       4., 6., 4.])

在时间序列上训练方法

训练循环神经网络中训练算法是基于时间维度上的反向传播算法。循环神经网反向传播是基于我们在前馈网络的反向传播算法。只要掌握了神经网络的反向传播算法,基于时间的反向传播也不太难理解了。主要区别是,在一定的时间序列的循环网络需要按时间进行展开。在展开网络之后,就看到了我们熟悉的类似于普通的多层前馈网络的模型。区别是每个层有多个输入。

X_k \times W_x + S_k \times w_{rec}

def update_state(xk, sk, wx, wRec):
    """
    更新当前状态,当前状态当前 k 状态
    """
    return xk * wx + sk * wRec
print(X.shape[0],X.shape[1])
20 10
S = np.zeros((X.shape[0],X.shape[1]+1))
print(S.shape)
(20, 11)
for k in range(0,X.shape[1]):
    S[:,k+1] = update_state(X[:,k],S[:,k],1.2,1.2)

print(S)
[[ 0.          1.2         2.64        3.168       3.8016      4.56192
   6.674304    9.2091648  12.25099776 14.70119731 18.84143677]
 [ 0.          0.          1.2         2.64        3.168       3.8016
   5.76192     6.914304    8.2971648   9.95659776 11.94791731]
 [ 0.          0.          1.2         1.44        2.928       4.7136
   5.65632     7.987584   10.7851008  14.14212096 18.17054515]
 [ 0.          0.          1.2         2.64        3.168       5.0016
   6.00192     8.402304   10.0827648  12.09931776 14.51918131]
 [ 0.          0.          1.2         1.44        2.928       3.5136
   4.21632     5.059584    6.0715008   8.48580096 10.18296115]
 [ 0.          0.          0.          1.2         1.44        2.928
   3.5136      5.41632     6.499584    8.9995008  10.79940096]
 [ 0.          1.2         2.64        3.168       5.0016      6.00192
   8.402304   11.2827648  13.53931776 17.44718131 20.93661757]
 [ 0.          1.2         2.64        3.168       5.0016      7.20192
   9.842304   11.8107648  14.17291776 18.20750131 23.04900157]
 [ 0.          1.2         2.64        4.368       6.4416      7.72992
  10.475904   13.7710848  17.72530176 21.27036211 26.72443453]
 [ 0.          0.          1.2         2.64        3.168       3.8016
   4.56192     5.474304    7.7691648  10.52299776 12.62759731]
 [ 0.          0.          0.          1.2         1.44        1.728
   2.0736      3.68832     4.425984    5.3111808   7.57341696]
 [ 0.          0.          0.          0.          1.2         2.64
   3.168       5.0016      6.00192     8.402304   10.0827648 ]
 [ 0.          0.          1.2         2.64        4.368       6.4416
   7.72992    10.475904   12.5710848  16.28530176 20.74236211]
 [ 0.          1.2         1.44        2.928       3.5136      5.41632
   7.699584   10.4395008  13.72740096 17.67288115 21.20745738]
 [ 0.          1.2         1.44        1.728       3.2736      3.92832
   5.913984    8.2967808   9.95613696 11.94736435 15.53683722]
 [ 0.          0.          0.          1.2         2.64        3.168
   5.0016      6.00192     8.402304   10.0827648  13.29931776]
 [ 0.          0.          1.2         1.44        1.728       2.0736
   2.48832     4.185984    6.2231808   8.66781696 11.60138035]
 [ 0.          1.2         2.64        4.368       5.2416      7.48992
   8.987904   10.7854848  12.94258176 15.53109811 18.63731773]
 [ 0.          0.          1.2         1.44        2.928       4.7136
   5.65632     7.987584   10.7851008  12.94212096 16.73054515]
 [ 0.          0.          0.          0.          1.2         1.44
   2.928       3.5136      5.41632     6.499584    8.9995008 ]]
def forward_states(X, wx, wRec):
    """
    
    """
    #  初始化一个持有所有序列的状态的矩阵,初始化值为 0
    S = np.zeros((X.shape[0], X.shape[1]+1))
    #  使用上面定义 update_state 方法时间序列上状态
    for k in range(0, X.shape[1]):
        # S[k] = S[k-1] * wRec + X[k] * wx
        S[:,k+1] = update_state(X[:,k], S[:,k], wx, wRec)
    return S

最后希望大家关注我们微信公众号


wechat.jpeg

相关文章

网友评论

    本文标题:2020自己动手写循环神经网(1)

    本文链接:https://www.haomeiwen.com/subject/pencfhtx.html