想要千行代码搞定Transformer?这份高效的PaddlePaddle官方实现请收下
作者:媒体转发 时间:2019-04-09 16:09
(导语)想要做个神经机器翻译模型?想要做个强大的Transformer?搞定这千行PaddlePaddle代码你也可以。
目前,无论是从性能、结构还是业界应用上,Transformer 都有很多无可比拟的优势。本文将介绍Paddle Paddle 的Transformer项目,我们从项目使用到源码解析带你玩一玩NMT。只需千行模型代码,Transformer实现带回家。
其实PyTorch、TensorFlow等主流框架都有Transformer的实现,但如果我们需要将它们应用到产品中,还是需要修改很多。
例如谷歌大脑构建的Tensor2Tensor,它最开始是为了实现 Transformer,后来扩展到了各种任务。对于基于Tensor2Tensor实现翻译任务的用户,他们需要在10万+行TensorFlow代码找到需要的部分。
PaddlePaddle 提供的Transformer实现,项目代码只有2000+行,简洁优雅。如果我们使用大Batch Size,那么在预测速度上,PaddlePaddle复现的模型比TensorFlow官方使用tensor2tensor实现的模型还要快4倍。
项目地址:https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleNLP/neural_machine_translation/transformer
Transformer怎么用
相比此前 Seq2Seq 模型中广泛使用的循环神经网络,Transformer 使用深层注意力机制获得了更好的效果,目前大多数神经机器翻译模型都采用了这一网络结构。此外,不论是新兴的预训练语言模型,还是问答或句法分析,Transformer都展现出强大的建模能力。
相比传统NMT使用循环层或卷积层抽取文本信息,Transformer使用自注意力网络抽取并表征这些信息,下图对比了不同层级的特点:

图注:不同网络的主要性质,其中n表示序列长度、d为隐向量维度、k为卷积核大小。例如单层计算复杂度,一般句子长度n都小于隐向量维度d,那么自注意力层级的计算复杂度最小。
如上所示,Transformer使用的自注意力模型主要拥有以下优点,1)网络结构的计算复杂度最低;2)由于序列操作数复杂度低,模型的并行度很高;3)最大路径长度小,能够更好地表示长距离依赖关系;4)模型更容易训练。
现在,如果我们需要训练一个Transformer,那么最好的方法是什么?当然是直接跑已复现的模型了,下面我们将跑一跑PaddlePaddle 实现的Transformer。
处理数据
在Paddle的复现中,百度采用原论文测试的WMT'16 EN-DE 数据集,它是一个中等规模的数据集。这里比较方便的是,百度将数据下载和预处理等过程都放到了gen_data.sh脚本中,包括Tokenize 和 BPE 编码。
在这个项目中,我们既可以通过脚本预处理数据,也可以使用百度预处理好的数据集。首先最简单的方式是直接运行gen_data.sh脚本,运行后可以生成gen_data文件夹,该文件夹主要包含以下文件:

其中 wmt16_ende_data_bpe 文件夹包含最终使用的英德翻译数据。
如果我们从头下载并预处理数据,那么大概需要花1到2个小时完成预处理。为此,百度也提供了预处理好的WMT'16 EN-DE数据集,它包含训练、验证和测试所需要的BPE数据和字典。
其中,BPE策略会把稀疏词拆分为高频的子词,这样既能解决低频词无法训练的问题,也能合理降低词表规模。
如果不采用BPE的策略,要么词表的规模变得很大,从而使训练速度变慢或者显存太小而无法训练;要么一些低频词会当作未登录词处理,从而得不到训练。
预处理数据地址:https://transformer-res.bj.bcebos.com/wmt16_ende_data_bpe_clean.tar.gz
如果我们有其它数据集,例如中英翻译数据,也可以根据特定的格式进行定义。例如用空格分隔不同的token(对于中文而言需要提前用分词工具进行分词),用\t分隔源语言与目标语句对。
训练模型
如果需要执行模型训练,我们也可以直接运行训练主函数train.py。如下简要配置了数据路径以及各种模型参数:
# 显存使用的比例,显存不足可适当增大,最大为1
export FLAGS_fraction_of_gpu_memory_to_use=0.8
# 显存清理的阈值,显存不足可适当减小,最小为0,为负数时不启用
export FLAGS_eager_delete_tensor_gb=0.7
python -u train.py \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '' '' '' \
--train_file_pattern gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--token_delimiter ' ' \
--use_token_batch True \
--batch_size 1600 \
--sort_type pool \
--pool_size 200000 \
n_head 8 \
n_layer 4 \
d_model 512 \
d_inner_hid 1024 \
prepostprocess_dropout 0.3


