为什么匹配logits是模型蒸馏的一个特例?

蒸馏技术的核心是将一个复杂的大模型(称为原始模型, model)中学到的知识迁移到一个较小的模型中,这个较小的模型被称为蒸馏模型( model)。原始模型通常规模庞大,参数众多,虽然精度高,但不易部署和运行效率低;而蒸馏模型则更小,更高效,适合实际应用。

知识迁移的关键在于让蒸馏模型不仅仅学习训练样本上的硬标签(hard ,即真实类别标签),还能从原始模型提供的丰富细节信息中获得泛化能力。这些细节被称为「知识」,反映了原始模型如何区分各类数据的微妙差别。

是什么?

在深度学习中,模型的输出通常经过一个叫做的函数,将模型对每个类别的非归一化分数(称为)转化为概率分布。是指模型在之前的输出向量元素,它们可以是任意实数,体现模型对每个类别的相对倾向强弱。

例如,一个三分类任务的可能是

2.5, 0.3, -1.2

,会将它们转成概率,如

0.78, 0.18, 0.04

,其中最高概率对应预测的类别。但本身携带更多关于类别之间相对关系的信息,这在之后的知识蒸馏中至关重要。

迁移集和软目标分布

在蒸馏过程中,原始模型会基于一个特定的数据集(称为迁移集, set),生成对每个输入的预测概率分布,这就是软目标分布(soft )。和传统的硬标签不同,软目标不仅仅告诉蒸馏模型哪个类别是真的,还反映了模型对所有类别的相对置信度,携带复杂的类间相似度信息。

这些软目标往往比硬标签信息量大,能指导小模型学到更丰富的泛化能力,有时即使迁移集中缺乏某些类别的实例,小模型依然可以对这类数据做出合理的预测。

温度参数的作用

为了让软目标信息更丰富,蒸馏采用了一个技巧:将函数中的温度参数T调高。温度控制输出概率的平滑程度:

在蒸馏训练中,既对原始模型的输出使用高温度生成软目标,又用相同高温度来训练蒸馏模型以匹配这些软目标,训练完成后推理时恢复温度为1。

为什么匹配是蒸馏的一个特例?

通过数学推导可以发现:在温度T非常高的极限下,函数近似变成线性变换,软目标概率趋近于之间的线性关系。特别地,如果将迁移集上的进行中心化处理(每个样本的各自减去均值),蒸馏过程中匹配软目标的交叉熵损失函数等价于最小化两个模型之间的均方误差。

换句话说:

这一特例的优势在于计算简单且直观,但在实际中,使用中等温度的软目标能更好地忽略极端负带来的噪声,提升蒸馏模型的泛化性能。

总的来说,知识蒸馏通过利用软目标分布让蒸馏模型学习原始模型隐含的丰富知识,而温度参数则调控的平滑程度,增强软目标的表达力。在温度极高的情况下,蒸馏过程退化成了简单的匹配,揭示了匹配其实是蒸馏的一个特例。

参考资料

, , Oriol , and Jeff Dean. ” the in a .” arXiv arXiv:1503.02531 (2015).

【往期回顾】

模型蒸馏_Logits迁移_知识蒸馏
© 版权声明

相关文章

暂无评论

none
暂无评论...