美文网首页
LSTM网络核心代码解读(一)

LSTM网络核心代码解读(一)

作者: CrazyCat_007 | 来源:发表于2019-02-12 01:15 被阅读0次

前言:

最近喜欢上研究LSTM网络,反复的分析它的结构,想深入去探究它的内部实现流程,下面是我对其核心代码的解读。

参数解读:

设批量样本(句子)数量batch_size = 100,每个句子的词语数量num_timesteps = 50(句子过长的截断,过短的做padding补全,可以使用'<UNK>'字符代替),每个词语的词向量num_embedding_size = 16,使用2层LSTM网络num_lstm_layers = 2,每层32个神经元num_lstm_nodes = [32, 32],全连接层为32个神经元num_fc_nodes = 32,用于控制梯度,防止梯度爆炸的参数clip_lstm_grads = 1.0,学习率learning_rate = 0.001.

部分核心代码:

embedding_initializer = tf.random_uniform_initializer(-1.0,1.0)

    with tf.variable_scope('embedding',initializer=embedding_initializer):

        embeddings = tf.get_variable('embedding',[vocab_size,hps.num_embedding_size],tf.float32)

        embed_inputs = tf.nn.embedding_lookup(embeddings,inputs)

    scale = 1.0/math.sqrt(hps.num_embedding_size+hps.num_lstm_nodes[-1])/3.0

    lstm_init = tf.random_uniform_initializer(-scale,scale)

    with tf.variable_scope('lstm_nn',initializer=lstm_init):

        cells = []

        for i in range(hps.num_lstm_layers):

            cell = tf.contrib.rnn.BasicLSTMCell(hps.num_lstm_nodes[i],state_is_tuple=True)

            cell = tf.contrib.rnn.DropoutWrapper(cell,output_keep_prob=keep_prob)

            cells.append(cell)

        cell = tf.contrib.rnn.MultiRNNCell(cells)

        initial_state = cell.zero_state(batch_size,tf.float32)

        rnn_outputs,state = tf.nn.dynamic_rnn(cell,embed_inputs,initial_state=initial_state)

        last = rnn_outputs[:,-1,:]

零碎知识点:

1.    tf.get_variable()、tf.Variable()和tf.placeholder()的区别:

tf.placeholder()是占位符,在构建计算图时起到占位的作用,不需要初始化,一般多应用于输入输出数据,具体数据应用于sess.run()中的feed_dict参数。

tf.get_variable()和tf.Variable()用于定义变量,一般需要初始化变量参数,它们的区别在于每个tf.Variable()都需要创建新的变量域,而在tf.get_variable()中可以共享或复用已经存在的变量参数。

2.    关于scale

类似于Xavier和He初始化方法,表示的是 1/(sqrt(in_nodes, out_nodes) * 3)。

核心代码解读

1.    embeddings = tf.get_variable('embedding',[vocab_size,hps.num_embedding_size],tf.float32)

       embed_inputs = tf.nn.embedding_lookup(embeddings,inputs)

初始化词向量。vocab_size为词表的大小,如词表中有1000个不同的单词(即1000行),每个单词对应一组词向量embedding,那么embeddings的大小即为1000*16。

inputs是输入的句子,假设一行句子有50个词,每个词在词表中对应的行数id(从0开始)为[10, 20, 30, ...],那么embed_inputs对应的即为[embeddings[10],embeddings[20],embeddings[30],...]。

2.    cell = tf.contrib.rnn.BasicLSTMCell(hps.num_lstm_nodes[i],state_is_tuple=True)

       cell = tf.contrib.rnn.DropoutWrapper(cell,output_keep_prob=keep_prob)

官方提示说tf.contrib.rnn.BasicLSTMCell()以后会被废弃,新的方法是tf.nn.rnn_cell.LSTMCell()。这里的cell即为一个LSTM网络,有32个神经元作为网络内部的隐藏神经元,官方建议状态参数state_is_tuple设置为True,用于将状态结果以元组的形式存储。输入数据进入LSTM网络后经过32个神经元进行全连接计算,为了防止过拟合,使用tf.contrib.rnn.DropoutWrapper()进行dropout,output_keep_prob是选择保留数据的比重,1.0为全部保留。

3.    cell = tf.contrib.rnn.MultiRNNCell(cells)

这里使用了两层LSTM网络,cells列表中包含了两个cell,tf.contrib.rnn.MultiRNNCell()自动将两层网络连接起来,将第一层的输出作为第二层的输入,将两层LSTM网络结合后相当于一层LSTM网络,输出为第二层网络的输出。

4.    initial_state = cell.zero_state(batch_size,tf.float32)

LSTM网络的状态值需要做初始化为全0,这里的initial_state的大小为100*32,cell的参数大小也为100*32,因为这里批量处理100个句子,LSTM网络的神经元为32个,每个句子对应32个状态值。

5.    rnn_outputs,state = tf.nn.dynamic_rnn(cell,embed_inputs,initial_state=initial_state)

       last = rnn_outputs[:,-1,:]

将每个句子的50个词语对应的embedding值依次输入到LSTM网络中,使用tf.nn.dynamic_rnn()动态地调整每个词在网络中对应32个神经元的权值参数,并将前一个词语输出的参数作为下一个词语的输入参数,最终输的是每个句子的最后一个词语的输出值和状态值。这里rnn_outputs的大小为100*50*32,state值大小为100*32。由于每个句子取的是最后一个词的输出,因为last值最终大小为100*32。

关于state值,我查阅了一些资料,发现state有两个值,c(cell state)和h(hidden state)。在每个batch完成后,c的值需要重置为0,而h为最终的状态值输出。对于多层LSTM网络来说,取state[-1].h即状态元组的最后一组值中的h其实是等同于last值的。

相关文章

  • LSTM网络核心代码解读(一)

    前言: 最近喜欢上研究LSTM网络,反复的分析它的结构,想深入去探究它的内部实现流程,下面是我对其核心代码的解读。...

  • lstm示例

    tensorflow下用LSTM网络进行时间序列预测 用LSTM做时间序列预测的思路,tensorflow代码实现...

  • 使用 bi-LSTM 对文本进行特征提取

    该部分内容通过代码注释的形式说明。 一、TextCNN 核心部分代码如下,这里主要关注 LSTM 类的内容。 二、...

  • keras lstm return sequence参数理解

    使用keras构建多层lstm网络时,除了最后一层lstm,中间过程的lstm中的return sequence参...

  • 详解 LSTM

    今天的内容有: LSTM 思路 LSTM 的前向计算 LSTM 的反向传播 关于调参 LSTM 长短时记忆网络(L...

  • Tensorflow神经网络之LSTM

    LSTM 简介 公式 LSTM LSTM作为门控循环神经网络因此我们从门控单元切入理解。主要包括: 输入门:It ...

  • 形象深刻理解lstm

    LSTM 长短时记忆网络(Long Short Term Memory Network, LSTM),是一种改进之...

  • LSTM网络

    之前和大家介绍了循环神经网络(RNN),RNN的魅力在于它能够很好地利用历史信息。例如,使用前一时刻的视频帧可以推...

  • NLP自然语言理解的学习

    RNN和LSTM网络结构 原文LSTM in pytorch Pytorch上的教学word_embeddings...

  • 第三次打卡-2020-02-16

    学习笔记 一、循环神经网络 LSTM 长短期记忆(Long short-term memory, LSTM)是一种...

网友评论

      本文标题:LSTM网络核心代码解读(一)

      本文链接:https://www.haomeiwen.com/subject/ymnceqtx.html