美文网首页Machine Learning & Recommendation & NLP & DL
自然语言处理N天-实现Transformer加载数据方法

自然语言处理N天-实现Transformer加载数据方法

作者: 我的昵称违规了 | 来源:发表于2019-03-05 11:50 被阅读4次
新建 Microsoft PowerPoint 演示文稿 (2).jpg

这个算是在课程学习之外的探索,不过希望能尽快用到项目实践中。在文章里会引用较多的博客,文末会进行reference。
搜索Transformer机制,会发现高分结果基本上都源于一篇论文Jay Alammar的《The Illustrated Transformer》(图解Transformer),提到最多的Attention是Google的《Attention Is All You Need》。

  • 对于Transformer的运行机制了解即可,所以会基于这篇论文来学习Transformer,结合《Sklearn+Tensorflow》中Attention注意力机制一章完成基本的概念学习;
  • 找一个基于Transformer的项目练手

5.代码实现

构建data_load

import tensorflow as tf
from utils import calc_num_batches

def load_vocab(vocab_fpath):
    '''
    加载词文件,返回一个idx<->token的图
    :param vocab_fpath: 字符串,词文件的地址  0: <pad>, 1: <unk>, 2: <s>, 3: </s>
    :return: 两个字典
    '''
    vocab=[line.split() for line in open(vocab_fpath,'r',encoding='utf-8').read().splitlines()]
    token2idx={token:idx for idx, token in enumerate(vocab)}
    idx2token={idx:token for idx, token in enumerate(vocab)}

    return token2idx, idx2token

def load_data(fpath1,fpath2,maxlen1,maxlen2):
    '''
    加载源语和目标语数据,筛选出最长的样例,用于生成掩码
    :param fpath1: 源语地址
    :param fpath2: 目标语地址
    :param maxlen1: 源语句子中最长的长度
    :param maxlen2: 目标语句子中最长的长度
    :return: 
    '''
    sents1, sents2 = [], []
    with open(fpath1, 'r') as f1, open(fpath2, 'r') as f2:
        for sent1, sent2 in zip(f1, f2):
            if len(sent1.split()) + 1 > maxlen1: continue  # 1: </s>
            if len(sent2.split()) + 1 > maxlen2: continue  # 1: </s>
            sents1.append(sent1.strip())
            sents2.append(sent2.strip())
    return sents1, sents2

def encode(inp, type, dict):
    '''
    将字符串转为数字,用于generator_fn。
    :param inp: 一维
    :param type: x表示源语,y表示目标语
    :param dict: token2idx字典
    :return: 数字列表
    '''
    inp_str=inp.decode("utf-8")
    if type=='x': tokens=inp_str.split()+["</s>"]
    else: tokens=["<s>"]+inp_str.split()+["</s>"]

    x=[dict.get(t,dict["<unk>"])for t in tokens]

    return x

def generator_fn(sents1, sents2, vocab_fpath):
    '''
    生成训练和评价数据
    :param sents1: 源语句子列表
    :param sents2: 目标句子列表
    :param vocab_fpath: 字符串,词文件地址
    '''
    token2idx, _ = load_vocab(vocab_fpath)
    for sent1, sent2 in zip(sents1, sents2):
        x = encode(sent1, "x", token2idx)
        y = encode(sent2, "y", token2idx)
        decoder_input, y = y[:-1], y[1:]

        x_seqlen, y_seqlen = len(x), len(y)
        yield (x, x_seqlen, sent1), (decoder_input, y, y_seqlen, sent2)

def input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=False):
    '''
    批量化数据
    :param sents1: 源语句子列表
    :param sents2: 目标句子列表
    :param vocab_fpath: 字符串,词文件地址
    batch_size: scalar
    shuffle: boolean

    Returns
    xs: tuple of
        x: int32 tensor. (N, T1)
        x_seqlens: int32 tensor. (N,)
        sents1: str tensor. (N,)
    ys: tuple of
        decoder_input: int32 tensor. (N, T2)
        y: int32 tensor. (N, T2)
        y_seqlen: int32 tensor. (N, )
        sents2: str tensor. (N,)
    '''
    shapes = (([None], (), ()),
              ([None], [None], (), ()))
    types = ((tf.int32, tf.int32, tf.string),
             (tf.int32, tf.int32, tf.int32, tf.string))
    paddings = ((0, 0, ''),
                (0, 0, 0, ''))

    dataset = tf.data.Dataset.from_generator(
        generator_fn,
        output_shapes=shapes,
        output_types=types,
        args=(sents1, sents2, vocab_fpath))  # <- arguments for generator_fn. converted to np string arrays

    if shuffle: # for training
        dataset = dataset.shuffle(128*batch_size)

    dataset = dataset.repeat()  # iterate forever
    dataset = dataset.padded_batch(batch_size, shapes, paddings).prefetch(1)

    return dataset

def get_batch(fpath1, fpath2, maxlen1, maxlen2, vocab_fpath, batch_size, shuffle=False):
    '''
    获取训练和评价小型数据
    fpath1: source file path. string.
    fpath2: target file path. string.
    maxlen1: source sent maximum length. scalar.
    maxlen2: target sent maximum length. scalar.
    vocab_fpath: string. vocabulary file path.
    batch_size: scalar
    shuffle: boolean

    Returns
    batches
    num_batches: number of mini-batches
    num_samples
    '''
    sents1, sents2 = load_data(fpath1, fpath2, maxlen1, maxlen2)
    batches = input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=shuffle)
    num_batches = calc_num_batches(len(sents1), batch_size)
    return batches, num_batches, len(sents1)

相关文章

网友评论

    本文标题:自然语言处理N天-实现Transformer加载数据方法

    本文链接:https://www.haomeiwen.com/subject/sfynuqtx.html