美文网首页
SoftMaskedBert论文学习笔记

SoftMaskedBert论文学习笔记

作者: 欠我的都给我吐出来 | 来源:发表于2021-01-06 18:04 被阅读0次

一、简单介绍

这是一篇中文纠错模型,来自今日头条公司。论文使用了一个错别字检测器detector和错别字纠正器corrector两个模型串联的方式进行错别字的检测和纠正。其中错别字检测器使用了一个普通的Bi-GRU网络,而错别字纠正器使用了业界的Bert encoder模块。

论文的创新点在于SoftMaskedBert,传统的bert算法在输入训练的时候,会随机给15%的文字进行遮蔽mask操作,这样模型会通过注意力机制根据上下文学习到masked部分的内容。这个mask操作实在embedding之前进行的,意味着所有被随机选中mask的汉字都会使用符号'[MASK]'替换,然后在embedding的时候会被映射为同样的一个embedding表示e_mask(默认是一个768维的向量)。而sodtmasked含义则是,在进入corrector之前,首先根据检测器对每个字的判断结果进行遮蔽。检测器的输出是每个汉字是否为错别字的概率【0-1】之间,概率值越大表示越有可能是错别字,那么该汉字在输入bert模型的表示为 截图 (4).png
其中e_mask是符号'[MASK]'的embedding表示,而e_i是原始汉字的embedding表示。pi为detector检测为错别字的概率。所以这就是softMasked的含义。

二、模型结构

截图 (5).png

网络结构如上图所示,比较明确。detection网络是一个bi-GRU网络,输入为语句通过embedding之后的向量表示[max_length,hidden_size],其中max_length为语句的长度,一般会按照实际语句的最大值进行填充,比如512,但是实际训练的时候,考虑到训练效率可以设置的短一点,因为一般的语句也就是50字以内。输出是每个汉字的错别字概率。

中间是softmasking的过程,这个过程会参考检测器的结果尽量mask错别字部分。

将中间层的embedding输出输入到bert模型中,输出结果首先和输入的embedding求和叠加,最后再通过一个softmax层得到最终结果。

        e = self.embedding(input_ids=input_ids, token_type_ids=segment_ids)
        p = self.detector(e)
        e_ = p * self.mask_e + (1-p) * e
        _, _, _, _, _, head_mask, encoder_hidden_states, encoder_extended_attention_mask= self._init_inputs(input_ids, input_mask)
        h = self.corrector(e_, attention_mask=encoder_extended_attention_mask,                    head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask)
        h = h[0] + e
        out1 = self.linear(h)
        out2 = self.softmax(out1)

三、损失函数

损失函数由两个部分组成,detector的损失函数以及纠正器的损失函数。其中,detector因为输出为二值概率,所以损失函数采用二元交叉熵损失函数。纠正器输出为多元交叉熵损失.两个损失函数按照固定比例进行加权得到最终的损失函数。


截图 (6).png

在实际使用的时候,对检测器直接使用nn.BCELoss()即可,对于纠正器则按照实际的情况使用损失函数

微信图片_20210105115009.jpg

四、实验效果

效果很好,需要注意到bert模型进行finetune之后的效果也不差,但是我现在还不是很清楚bert模型直接finetune的话,训练数据应该怎么标注,暂且先放在这里,以后补上。


截图 (7).png

实际使用的话,git上有两份开源代码,地址如下:

  1. https://github.com/hiyoung123/SoftMaskedBert

  2. https://github.com/SonofGod-lucky/Pytorch_SoftMaskBert_SonofGod

其中第一份代码,我自己尝试使用大约800万优质训练集进行训练,在corrector准确率为30%的时候,loss就不下降了,尝试修改过学习率和损失函数的参数都没有什么效果。暂时没有看出什么问题,先搁置一下。

后面一份代码的问题比较大,需要好好修改之后运行,目前还在训练中,有结果再记录吧。

相关文章

网友评论

      本文标题:SoftMaskedBert论文学习笔记

      本文链接:https://www.haomeiwen.com/subject/jbkooktx.html