知识蒸馏(Knowledge Distillation)(「知识」是什么

大家好,又见面了,我是你们的朋友风君子。

本文主要罗列与知识蒸馏相关的一些算法与应用。但首先需要明确的是,教师网络或给定的预训练模型中包含哪些可迁移的知识?基于常见的深度学习任务,可迁移知识列举为:

  • 中间层特征:浅层特征注重纹理细节,深层特征注重抽象语义;
  • 任务相关知识:如分类概率分布,目标检测涉及的实例语义、位置回归信息等;
  • 表征相关知识:强调特征表征能力的迁移,相对通用、任务无关(Task-agnostic);

知识蒸馏(Knowledge Distillation)

1、Distilling the Knowledge in a Neural Network

Hinton的文章”Distilling the Knowledge in a Neural Network”首次提出了知识蒸馏(暗知识提取)的概念,通过引入与教师网络(Teacher network:复杂、但预测精度优越)相关的软目标(Soft-target)作为Total loss的一部分,以诱导学生网络(Student network:精简、低复杂度,更适合推理部署)的训练,实现知识迁移(Knowledge transfer)。

知识蒸馏(Knowledge Distillation)

如上图所示,教师网络(左侧)的预测输出除以温度参数(Temperature)之后、再做Softmax计算,可以获得软化的概率分布(软目标或软标签),数值介于0~1之间,取值分布较为缓和。Temperature数值越大,分布越缓和;而Temperature数值减小,容易放大错误分类的概率,引入不必要的噪声。针对较困难的分类或检测任务,Temperature通常取1,确保教师网络中正确预测的贡献。硬目标则是样本的真实标注,可以用One-hot矢量表示。Total loss设计为软目标与硬目标所对应的交叉熵的加权平均(表示为KD loss与CE loss),其中软目标交叉熵的加权系数越大,表明迁移诱导越依赖教师网络的贡献,这对训练初期阶段是很有必要的,有助于让学生网络更轻松的鉴别简单样本,但训练后期需要适当减小软目标的比重,让真实标注帮助鉴别困难样本。另外,教师网络的预测精度通常要优于学生网络,而模型容量则无具体限制,且教师网络推理精度越高,越有利于学生网络的学习。

教师网络与学生网络也可以联合训练,此时教师网络的暗知识及学习方式都会影响学生网络的学习,具体如下(式中三项分别为教师网络Softmax输出的交叉熵loss、学生网络Softmax输出的交叉熵loss、以及教师网络数值输出与学生网络Softmax输出的交叉熵loss):

联合训练的Paper地址:https://arxiv.org/abs/1711.05852

知识蒸馏(Knowledge Distillation)

知识蒸馏(Knowledge Distillation)

2、Exploring Knowledge Distillation of Deep Neural Networks for Efficient Hardware Solutions

GitHub地址:https://github.com/peterliht/knowledge-distillation-pytorch

这篇文章将Total loss重新定义如下:

知识蒸馏(Knowledge Distillation)

Total loss的PyTorch代码如下,引入了精简网络输出与教师网络输出的KL散度,并在诱导训练期间,先将Teacher network的预测输出缓存到CPU内存中,可以减轻GPU显存的Overhead:

def loss_fn_kd(outputs, labels, teacher_outputs, params):
    """
    Compute the knowledge-distillation (KD) loss given outputs, labels.
    "Hyperparameters": temperature and alpha
    NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
    and student expects the input tensor to be log probabilities! See Issue #2
    """
    alpha = params.alpha
    T = params.temperature
    KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
                             F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
                             F.cross_entropy(outputs, labels) * (1. - alpha)

    return KD_loss

3、Ensemble of Multiple Teachers

Paper地址: Efficient Knowledge Distillation from an Ensemble of Teachers | Request PDF

第一种算法:多个教师网络输出的Soft label按加权组合,构成统一的Soft label,然后指导学生网络的训练:

知识蒸馏(Knowledge Distillation)

知识蒸馏(Knowledge Distillation)

第二种算法:由于加权平均方式会弱化、平滑多个教师网络的预测结果,因此可以随机选择某个教师网络的Soft label作为Guidance:

知识蒸馏(Knowledge Distillation)

第三种算法:同样地,为避免加权平均带来的平滑效果,首先采用教师网络输出的Soft label重新标注样本、增广数据、再用于模型训练,该方法能够让模型学会从更多视角观察同一样本数据的不同功能:

知识蒸馏(Knowledge Distillation)

4、Hint-based Knowledge Transfer

Paper地址:https://arxiv.org/abs/1412.6550

GitHub地址:https://github.com/adri-romsor/FitNets

为了能够诱导训练更深、更紧凑的学生网络(Deeper and thinner FitNet),需要考虑教师网络中间层的Feature Maps(作为Hint),用来指导学生网络中相应的Guided layer。此时需要引入L2 loss指导训练过程,该loss计算为教师网络Hint layer与学生网络Guided layer输出Feature Maps之间的差别,若二者输出的Feature Maps形状不一致,Guided layer需要通过一个额外的回归层,具体如下:

知识蒸馏(Knowledge Distillation)

知识蒸馏(Knowledge Distillation)

具体训练过程分两个阶段完成:第一个阶段利用Hint-based loss诱导学生网络达到一个合适的初始化状态(只更新W_Guided与W_r);第二个阶段利用教师网络的soft label指导整个学生网络的训练(即知识蒸馏),且Total loss中Soft target相关部分所占比重逐渐降低,从而让学生网络能够全面辨别简单样本与困难样本(教师网络能够有效辨别简单样本,而困难样本则需要借助真实标注,即Hard target):

知识蒸馏(Knowledge Distillation)

5、Attention to Attention Transfer

Paper地址:https://arxiv.org/abs/1612.03928

GitHub地址:https://github.com/szagoruyko/attention-transfer

通过网络中间层的Attention map,完成Teacher network与Student network之间的知识迁移。考虑给定的Tensor A,基于Activation的Attention map可以定义为如下三种之一:

知识蒸馏(Knowledge Distillation)

随着网络层次的加深,关键区域的Attention-level也随之提高。文章最后采用了第二种形式的Attention map,取p=2,并且Activation-based attention map的知识迁移效果优于Gradient-based attention map,loss定义及迁移过程如下:

知识蒸馏(Knowledge Distillation)

知识蒸馏(Knowledge Distillation)

知识蒸馏(Knowledge Distillation)

6、Flow of the Solution Procedure

Paper地址:

http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf

暗知识亦可表示为训练的求解过程(FSP: Flow of the Solution Procedure),教师网络或学生网络的FSP矩阵定义如下(Gram形式的矩阵):

知识蒸馏(Knowledge Distillation)

知识蒸馏(Knowledge Distillation)

训练的第一阶段:最小化教师网络FSP矩阵与学生网络FSP矩阵之间的L2 Loss,初始化学生网络的可训练参数:

知识蒸馏(Knowledge Distillation)

训练的第二阶段:在目标任务的数据集上fine-tune学生网络。从而达到知识迁移、快速收敛、以及迁移学习的目的。

7、Knowledge Distillation with Adversarial Samples Supporting Decision Boundary

Paper地址:https://arxiv.org/abs/1805.05532

从分类的决策边界角度分析,知识迁移过程亦可理解为教师网络诱导学生网络有效鉴别决策边界的过程,鉴别能力越强意味着模型的泛化能力越好:

知识蒸馏(Knowledge Distillation)

文章首先利用对抗攻击策略(Adversarial attacking)将基准类样本(Base class sample)转为目标类样本、且位于决策边界附近(BSS: boundary supporting sample),进而利用对抗生成的样本诱导学生网络的训练,可有效提升学生网络对决策边界的鉴别能力。文章采用迭代方式生成对抗样本,需要沿Loss function(基准类得分与目标类得分之差)的梯度负方向调整样本,直到满足停止条件为止:

知识蒸馏(Knowledge Distillation)

Loss function定义如下:

知识蒸馏(Knowledge Distillation)

沿Loss function的梯度负方向调整样本:

知识蒸馏(Knowledge Distillation)

停止条件(只要满足三者之一):

知识蒸馏(Knowledge Distillation)

结合对抗生成的样本,利用教师网络训练学生网络所需的Total loss包含CE loss、KD loss以及Boundary supporting loss(BS loss):

知识蒸馏(Knowledge Distillation)

8、Label Refinery:Improving ImageNet Classification through Label Progression

GitHub地址:https://github.com/hessamb/label-refinery

这篇文章通过迭代式的诱导训练,主要解决训练期间样本的Crop与Label不一致的问题,以增强Label的质量,从而进一步增强模型的泛化能力:

知识蒸馏(Knowledge Distillation)

诱导过程中,Total loss表示为本次迭代(t>1)网络的预测输出(概率分布)与上一次迭代输出(Label Refinery:类似于教师网络的角色)的KL散度:

知识蒸馏(Knowledge Distillation)

文章实验部分表明,不仅可以用训练网络作为Label Refinery Network,也可以用其他高质量网络(如Resnet50)作为Label Refinery Network。并在诱导过程中,能够对抗生成样本,实现数据增强。

9、Meal V2 KD (Ensemble of Multi-Teachers)

Paper地址:https://arxiv.org/abs/2009.08453

GitHub:https://github.com/szq0214/MEAL-V2

MEAL V2的基本思路是通过知识蒸馏,将多个Teacher模型的效果ensemble、迁移到一个Student模型中,包括:Teacher模型集成,KL散度loss以及判别器:

  • 多个Teacher的预测概率求平均;
  • 仅依靠Teacher的Soft label;
  • 判别器起到正则化作用;
  • Student从预训练模型开始,减少蒸馏训练的开销;

知识蒸馏(Knowledge Distillation)

10、KD for Lightweight Face Detector

Paper地址:

https://www.researchgate.net/publication/339172272_Learning_Lightweight_Face_Detector_with_Knowledge_Distillation

人脸检测模型的分类预测输出,属于典型的二分类(0-背景,1-人脸)。针对人脸检测模型,通常认为在教师网络与学生网络之间,Classification map的差异,要比Regression map更大;且Classification map提供的Soft label,更容易作为监督信息。另外,需要基于教师网络与学生网络输出得分的差异,过滤简单样本、实现在线难例挖掘,为学生网络的学习提供有效监督。

Loss function的实现如下:

知识蒸馏(Knowledge Distillation)

知识蒸馏(Knowledge Distillation)

def kd_loss(teacher_output, student_output, alpha=50.0):
    teacher_output = F.softmax(teacher_output, dim=-1)
    student_output = F.softmax(student_output, dim=-1)

    scale = 16.2  # when beta=6.4 and gamma=3.2
    beta = 6.4

    threshold = scale * torch.pow(torch.abs(teacher_output[:, :, 1] - 0.5), beta)

    mask = teacher_output[:, :, 1] > threshold

    t_feat = teacher_output[mask]
    f_feat = student_output[mask]
    loss = torch.nn.functional.mse_loss(t_feat, f_feat)
    return loss * alpha

11、Knowledge Distillation meets Self-supervision

Paper地址:https://arxiv.org/pdf/2006.07114.pdf

GitHub:GitHub – xuguodong03/SSKD

SSKD(using Self-Supervised learning as an auxiliary task for Knowledge Distillation )将自监督学习作为辅助任务,以执行知识蒸馏。在传统KD中,学生网络模仿教师网络关于任务层的预测输出(如分类、位置回归等);而在SSKD中,在变换后的数据集和自监督辅助任务上,能够实现更为丰富的结构化知识迁移。由于对比学习(Contrastive learning)在自监督学习中表现优秀,SSKD选择对比学习作为自监督辅助任务。对比学习通过使网络区分正负样本,最大化每个样本变换前后的相似度(基于Contrastive loss),使得模型学习到具有变换不变性的表征能力。文章使用余弦函数以衡量不同表征之间的相似度,并构造相似度矩阵;然后通过Cross entropy loss,以衡量变换后样本与某一原样本是否为正样本对:

# 4 means one original sample and three augmented samples
batch = int(x.size(0) / 4)
nor_index = (torch.arange(4*batch) % 4 == 0).cuda()
aug_index = (torch.arange(4*batch) % 4 != 0).cuda()

# rep is the representation features of all samples
nor_rep = rep[nor_index]
aug_rep = rep[aug_index]
nor_rep = nor_rep.unsqueeze(2).expand(-1, -1, 3*batch).transpose(0, 2)
aug_rep = aug_rep.unsqueeze(2).expand(-1, -1, 1*batch)
# cosine similarity is used for similarity matrix
simi = F.cosine_similarity(aug_rep, nor_rep, dim=1)
target = torch.arange(batch).unsqueeze(1).expand(-1, 3).contiguous().view(-1).long().cuda()
loss = F.cross_entropy(simi, target)

SSKD一方面要求学生学习任务相关知识(无论是正样本还是负样本的任务预测),另一方面要求学生模仿教师的特征表征能力,以有效区分正负样本(教师网络需要先执行对比学习,获取表征能力以供迁移),总的Loss function表示如下

知识蒸馏(Knowledge Distillation)

知识蒸馏(Knowledge Distillation)

知识蒸馏(Knowledge Distillation)

其中Lce表示Hard label driven loss,Lkd表示原样本的Hinton KD loss,Lss表示表征知识的迁移,LT表示变换样本的Hinton KD;B表示相似度矩阵的Softmax计算,表明学生需要学习教师预测的样本对相似度概率分布。在具体训练过程中,通过OHEM挖掘高质量的变换样本用于计算LT与Lss,排序依据分别为教师的Soft-label与相似度矩阵。其中LT的OHEM如下:

aug_target = target.unsqueeze(1).expand(-1, 3).contiguous().view(-1).long().cuda()
rank = torch.argsort(aug_knowledge, dim=1, descending=True)
rank = torch.argmax(torch.eq(rank, aug_target.unsqueeze(1)).long(), dim=1)  # groundtruth label's rank
index = torch.argsort(rank)
tmp = torch.nonzero(rank, as_tuple=True)[0]
wrong_num = tmp.numel()
correct_num = 3 * batch - wrong_num
wrong_keep = int(wrong_num * args.ratio_tf)
index = index[:correct_num + wrong_keep]
distill_index_tf = torch.sort(index)[0]

SSKD非常适合标注不充分(对于无标注场景,可将Lce移除)、及Few-shot应用场景(依靠对比学习迁移表征知识)。另外,由于SSKD仅依靠Final layer迁移知识,因此也适合异质网络的诱导训练。SSKD的具体应用框架如下图,教师网络与学生网络均由三部分组成:用于提取特征的Backbone,用于主任务的分类器以及用于辅助任务的自监督模块:

知识蒸馏(Knowledge Distillation)

12、Contrastive Pruning

Paper地址:https://arxiv.org/abs/2112.07198

GitHub:GitHub – RunxinXu/ContrastivePruning: Source code for our AAAI’22 paper 《From Dense to Sparse: Contrastive Pruning for Better Pre-trained Language Model Compression》

在下游任务的微调过程中,执行BERT剪枝与知识蒸馏;KD核心思想包括:

  • Teacher models包括预训练模型(强调任务无关的表征知识)、微调过的模型(强调任务相关知识)、以及剪枝过程中保存的模型(历史剪枝模型的信息);处于剪枝状态的模型,作为Student model;
  • 知识蒸馏同时包含任务相关KD(Soft-label)、与任务无关KD(自监督、对比损失);

知识蒸馏(Knowledge Distillation)

知识蒸馏(Knowledge Distillation)

13、Vision-language Knowledge Distillation

Paper地址:https://arxiv.org/abs/2203.06386

为了增强多模型生成任务的效果,以CLIP作为Teacher model,以BART (Encoder-decoder)作为Student model,实现了多模态表征知识的迁移:

  • CLIP包含Image encoder与Text encoder,具备统一、共享的多模态表征空间,能够实现视觉表征与文本表征的对齐;
  • 为了实现CLIP表征知识的迁移,引入了TTDL (Text-Text Distance Minimization)、ITCL (Image-Text Contrastive Learning)与ICTI (Image-Conditioned Text Infilling)三个迁移任务;迁移训练期间,CLIP主干参数冻结;
    • TTDL (L2 distance loss):

      知识蒸馏(Knowledge Distillation)

    • ITCL (InforNCE loss, used in contrastive learning):

      知识蒸馏(Knowledge Distillation)

      知识蒸馏(Knowledge Distillation)

      知识蒸馏(Knowledge Distillation)

    • ICTI (Sum of log-softmax loss, used in regressive decoding):

      知识蒸馏(Knowledge Distillation)

    • Total loss:

      知识蒸馏(Knowledge Distillation)

知识蒸馏(Knowledge Distillation)

 

14、Miscellaneous

——– 知识蒸馏可以与量化结合使用,考虑了中间层Feature Maps之间的关系,可参考:

结合量化的知识蒸馏(Quantization Mimic)_AI Flash-CSDN博客

——– 知识蒸馏与Hint Learning相结合,可以训练精简的Faster-RCNN,可参考:

目标检测网络的知识蒸馏_AI Flash-CSDN博客_目标检测 知识蒸馏

——– 网络结构搜索(NAS)也可以采用蒸馏操作,改善搜索效果,可参考(Cream NAS的Inter-model Distillation):

自蒸馏One-shot NAS——Cream of the Crop_AI Flash-CSDN博客

——– 知识蒸馏在Transformer模型压缩方面,主要采用Self-attention Knowledge Distillation,可参考:

Bert/Transformer模型压缩与优化加速_AI Flash-CSDN博客_transformer模型加速

Transformer端侧模型压缩——Mobile Transformer_AI Flash-CSDN博客

——– 模型压缩方面,更为详细的讨论,请参考:

深度学习模型压缩与优化加速(Model Compression and Acceleration Overview)_AI Flash-CSDN博客_深度学习模型压缩

Published by

风君子

独自遨游何稽首 揭天掀地慰生平

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注