基于Pytorch的知识蒸馏(中文文本分类),将用bert训练好的中文分类模型蒸馏到bilstm上。使用的是hugging face上的bert-base-chinese,可去自行下载。知识蒸馏主要是将bert输出的logits中的知识蒸馏到bilstm中的logits。
在进行知识蒸馏的过程中,顺带着做了以下其它的实验,比如:梯度累加、混合精度训练、对抗训练。
目录结构:
--data:数据文件,使用的是THUCNews数据,共10类。
--config:存放的配置文件,里面可以控制训练、验证、测试、预测,也可以控制使用其它的一些策略。
--models;模型文件。主要存放bert和bilstm模型代码。
--checkpoints:模型保存的路径。
--processor:数据处理相关。对于bilstm而言,使用单个字作为输入,且使用整理好的5000个字的词汇表。在蒸馏的时候,既要处理数据为bert的格式,也要处理数据为biltm所需的格式。
--utils:存放辅助函数文件目录。主要包含了设置随机种子、日志模块(暂未使用),以及对抗训练所需的模块(FGM、PGD)。
--main:带main的python文件是主运行文件,每个文件名都标识着使用的策略。其中,main.py就是分别训练bert模型或bilstm模型,main_with_gradient_accumulation以梯度累加的形式训练bert或bilstm模型。main_with_apex.py在梯度累加下使用混合精度训练训练bert或bilstm模型。main_with_attck.py是可选混合精度+对抗训练来训练bert或bilstm模型。
需要注意的是每个文件中模型的名称可能需要对应修改一下,并且修改相关配置文件中的参数来使用不同的策略。直接运行python main.py
即可。
除了学习率和batch_size会有相应的变动,其余的参数都是一致的。当设置ga_step=4时,将学习率调整为原来的4倍,即2e-5*4=2e-8
,梯度累加后的batch_size相当于是32*4=128
。
学习率:2e-5 batch_size:32
模型 | accuray | precision | recall | macro_f1 |
---|---|---|---|---|
bert | 0.9451 | 0.9473 | 0.9451 | 0.9448 |
bert+ga(ga_step=4) | 0.9352 | 0.9366 | 0.9352 | 0.9344 |
bert+apex | 0.9496 | 0.9505 | 0.9496 | 0.9495 |
bert+apex_ga | 0.9471 | 0.9476 | 0.9471 | 0.9464 |
bert+fgm | 0.9479 | 0.9483 | 0.9479 | 0.9473 |
bert+pgd | 0.9479 | 0.9483 | 0.9479 | 0.9473 |
bilstm | 0.8983 | 0.9018 | 0.8983 | 0.8934 |
bilstm+ga | 0.9191 | 0.9209 | 0.9191 | 0.9172 |
bilstm+apex | 0.9001 | 0.9038 | 0.9001 | 0.8956 |
bilstm+apex_ga | 0.9154 | 0.9198 | 0.9154 | 0.9137 |
bilstm+fgm | 0.8983 | 0.9018 | 0.8983 | 0.8934 |
bilstm+pgd | 0.8983 | 0.9018 | 0.8983 | 0.8934 |
bilstm+fgm+apex | 0.8983 | 0.9018 | 0.8983 | 0.8934 |
学习率:2e-5 batch_size:32
bert->bisltm | accuray | precision | recall | macro_f1 |
---|---|---|---|---|
origin | 0.8983 | 0.9018 | 0.8983 | 0.8934 |
T=1 | 0.9019 | 0.9062 | 0.9019 | 0.8988 |
T=5 | 0.9001 | 0.9038 | 0.9001 | 0.8972 |
T=10 | 0.8983 | 0.9041 | 0.8983 | 0.8944 |
T=20 | 0.9010 | 0.9049 | 0.9010 | 0.8980 |
1、使用知识蒸馏确实能够将大模型的知识蒸馏到小模型上,在一些文献中表明在蒸馏的时候要做一些数据增强,这里暂时未做。
2、使用梯度累加能够隐式的增加batchsize,能够避免直接增加batchsize导致的GPU显存不够问题。顺带提一下batchsize也不是越大越好,会达到一个极限后再增大batchszie导致性能下降。
3、混合精度训练能够加速网络的训练,而且性能可能也会有提升。
4、对抗训练会增加训练的时间,结合混合精度训练更佳。PGD的训练时长较FGM的更长。这里奇怪的是FGM和PGD的效果一样,还有点问题。
5、数据蒸馏时温度T也是一个可好好调的参数。