CNN-based English-German Machine Translation
基于纯CNN的英德机器翻译模型(不使用Transformer架构)
项目特点
- 纯CNN架构:使用卷积神经网络进行序列到序列的翻译,不使用Transformer
- ConvS2S模型:基于Facebook的ConvS2S(Convolutional Sequence to Sequence)论文
- 位置编码:使用正弦位置编码为CNN提供序列位置信息
- GLU激活:使用门控线性单元(Gated Linear Unit)作为激活函数
- 卷积注意力:使用卷积层实现注意力机制,而非自注意力
模型架构
编码器(Encoder)
- 词嵌入层 + 位置编码
- 多层CNN编码器层
- 每层包含:
- 卷积层(kernel_size=3)
- GLU激活函数
- 残差连接
- 层归一化
- Dropout
解码器(Decoder)
- 词嵌入层 + 位置编码
- 多层CNN解码器层
- 每层包含:
- 因果卷积(保持自回归性质)
- GLU激活函数
- 卷积注意力机制
- 残差连接
- 层归一化
- Dropout
关键特性
- 因果卷积:解码器使用左侧padding实现因果性,确保生成时不看未来信息
- 卷积注意力:使用卷积层而不是点积注意力,保持纯CNN架构
- 位置编码:为CNN提供序列顺序信息(CNN是位置不变的)
安装依赖
pipinstall-rrequirements.txt额外依赖(需要手动安装):
python-mspacy download en_core_web_sm python-mspacy download de_core_news_sm数据准备
下载数据集
运行以下命令下载Multi30k数据集(英德翻译):
python data_loader.py这将自动下载并预处理Multi30k数据集,保存到./data目录。
数据集结构
data/ ├── train.en # 训练集英文 ├── train.de # 训练集德语 ├── valid.en # 验证集英文 ├── valid.de # 验证集德语 ├── test.en # 测试集英文 └── test.de # 测试集德语训练模型
基本训练
python train.py--batch_size32--epochs10--d_model256--n_layers6参数说明
--batch_size: 批大小(默认: 32)--epochs: 训练轮数(默认: 10)--lr: 学习率(默认: 0.001)--d_model: 模型维度(默认: 256)--n_layers: CNN层数(默认: 6)--kernel_size: 卷积核大小(默认: 3)--clip: 梯度裁剪阈值(默认: 1.0)--data_dir: 数据目录(默认: ./data)--save_dir: 模型保存目录(默认: ./models)--resume: 恢复训练的检查点路径
训练示例
# 完整训练python train.py\--batch_size64\--epochs20\--d_model512\--n_layers8\--kernel_size5\--lr0.0005\--save_dir./models/cnn_translator# 恢复训练python train.py\--resume./models/cnn_translator/checkpoint_epoch_10.pt\--epochs20模型推理
交互式翻译
python translate.py\--model_path./models/cnn_translator/checkpoint_epoch_20.pt\--interactive批量翻译
python translate.py\--model_path./models/cnn_translator/checkpoint_epoch_20.pt\--input_fileinput_sentences.txt\--output_filetranslations.txt示例翻译
训练完成后,运行translate.py会显示示例翻译:
英文: Hello, how are you? 德语: Hallo, wie geht es Ihnen? 英文: I love machine learning. 德语: Ich liebe maschinelles Lernen. 英文: This is a test sentence. 德语: Dies ist ein Testsatz.项目结构
cnn-translator/ ├── requirements.txt # 依赖包列表 ├── README.md # 项目文档 ├── model.py # CNN Seq2Seq模型定义 ├── data_loader.py # 数据加载和预处理 ├── train.py # 训练脚本 ├── translate.py # 推理脚本 ├── data/ # 数据集目录 │ ├── train.en │ ├── train.de │ ├── valid.en │ ├── valid.de │ ├── test.en │ └── test.de └── models/ # 模型检查点 ├── checkpoint_epoch_1.pt ├── checkpoint_epoch_2.pt └── ...模型性能
优势
- 并行计算:CNN可以完全并行化,训练速度快于RNN
- 梯度流:残差连接使得深层网络易于训练
- 局部特征:卷积擅长捕捉局部语言模式(n-gram特征)
局限性
- 长程依赖:相比Transformer,CNN捕捉长距离依赖能力较弱
- 计算效率:对于极长序列,卷积的计算量可能较大
参考资料
- ConvS2S论文:Convolutional Sequence to Sequence Learning (Facebook AI, 2017)
- GLU激活:Language Modeling with Gated Convolutional Networks
- 位置编码:基于Transformer的位置编码方案
常见问题
Q1: 为什么不用Transformer?
A: 本项目是学习和研究CNN用于机器翻译的实现,适合理解CNN在序列任务中的应用。
Q2: 模型训练很慢怎么办?
A:
- 减小
d_model或n_layers - 减小
batch_size - 使用GPU加速(
device='cuda')
Q3: 翻译质量不好怎么办?
A:
- 增加训练轮数
- 使用更大的
d_model(如512或768) - 增加
n_layers(如8或10) - 使用更大的数据集(如WMT14)
Q4: 如何保存和恢复训练?
A: 使用--resume参数指定检查点路径,训练会自动恢复。
许可证
MIT License
作者
CNN机器翻译实现 - 基于PyTorch
注意:这是一个研究/教育项目,生产环境建议使用成熟的NMT工具(如Fairseq、OpenNMT等)。