一、简单介绍
这是一篇中文纠错模型,来自今日头条公司。论文使用了一个错别字检测器detector和错别字纠正器corrector两个模型串联的方式进行错别字的检测和纠正。其中错别字检测器使用了一个普通的Bi-GRU网络,而错别字纠正器使用了业界的Bert encoder模块。
截图 (4).png
其中e_mask是符号'[MASK]'的embedding表示,而e_i是原始汉字的embedding表示。pi为detector检测为错别字的概率。所以这就是softMasked的含义。
二、模型结构
截图 (5).png
网络结构如上图所示,比较明确。detection网络是一个bi-GRU网络,输入为语句通过embedding之后的向量表示[max_length,hidden_size],其中max_length为语句的长度,一般会按照实际语句的最大值进行填充,比如512,但是实际训练的时候,考虑到训练效率可以设置的短一点,因为一般的语句也就是50字以内。输出是每个汉字的错别字概率。
中间是softmasking的过程,这个过程会参考检测器的结果尽量mask错别字部分。
将中间层的embedding输出输入到bert模型中,输出结果首先和输入的embedding求和叠加,最后再通过一个softmax层得到最终结果。
e = self.embedding(input_ids=input_ids, token_type_ids=segment_ids)
p = self.detector(e)
e_ = p * self.mask_e + (1-p) * e
_, _, _, _, _, head_mask, encoder_hidden_states, encoder_extended_attention_mask= self._init_inputs(input_ids, input_mask)
h = self.corrector(e_, attention_mask=encoder_extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask)
h = h[0] + e
out1 = self.linear(h)
out2 = self.softmax(out1)
三、损失函数
损失函数由两个部分组成,detector的损失函数以及纠正器的损失函数。其中,detector因为输出为二值概率,所以损失函数采用二元交叉熵损失函数。纠正器输出为多元交叉熵损失.两个损失函数按照固定比例进行加权得到最终的损失函数。
截图 (6).png
在实际使用的时候,对检测器直接使用nn.BCELoss()即可,对于纠正器则按照实际的情况使用损失函数
微信图片_20210105115009.jpg
四、实验效果
效果很好,需要注意到bert模型进行finetune之后的效果也不差,但是我现在还不是很清楚bert模型直接finetune的话,训练数据应该怎么标注,暂且先放在这里,以后补上。
截图 (7).png
实际使用的话,git上有两份开源代码,地址如下:
其中第一份代码,我自己尝试使用大约800万优质训练集进行训练,在corrector准确率为30%的时候,loss就不下降了,尝试修改过学习率和损失函数的参数都没有什么效果。暂时没有看出什么问题,先搁置一下。
后面一份代码的问题比较大,需要好好修改之后运行,目前还在训练中,有结果再记录吧。











网友评论