基于知识蒸馏的模型训练方法、图像处理方法及装置与流程

专利2022-06-29  67


本公开一般地涉及图像识别领域,具体涉及一种基于知识蒸馏的模型训练方法、基于知识蒸馏的模型训练装置、图像处理方法、图像处理装置、电子设备和计算机可读存储介质。



背景技术:

随着人工智能识别的发展,普遍采用模型进行数据处理、图像识别等,不断提高识别精度和识别范围的同时,神经网络也越来越庞大,计算耗时、参数多并且所需存储容量巨大。因此很难将其应用于移动端,特别是硬件设备较差的移动端。

知识蒸馏是一种模型压缩方法,在教师-学生框架中,将复杂、学习能力强的教师模型学到的特征表示“知识”蒸馏出来,传递给参数量小、学习能力弱的学生模型。简单的说就是用新的小模型去学习大模型的预测结果,复杂模型或者组合模型的中“知识”通过合适的方式迁移到一个相对简单模型之中,进而方便模型推广部署。

目前一些知识蒸馏技术主要方法主要针对蒸馏的位置(如最后一层特征输出位置、特征图(featuremap)输出位置,神经网络softmax前的输出logit的位置)等和蒸馏的度量方式两个方向。但是传统方法对所有训练样本一视同仁,让学生模型尽可能模仿教师模型,但学生模型毕竟容量有限,不可能完美学到教师模型的所有知识,盲目模仿往往不能取得最优性能。



技术实现要素:

为了解决现有技术中存在的上述问题,本公开的第一方面提供一种基于知识蒸馏的模型训练方法,其中,应用于学生模型,方法包括:根据蒸馏位置,设置与蒸馏位置的第一输出层相同的第二输出层;获取训练集,训练集包括多个训练数据;基于训练数据,得到第一输出层输出的第一数据、以及第二输出层输出的第二数据;获取教师模型基于训练数据,在与蒸馏位置对应的教师层输出的监督数据,其中教师模型为已完成训练、且与学生模型完成相同任务的复杂模型;基于监督数据与第一数据的差距、以及第二数据,根据蒸馏损失函数得到蒸馏损失值;基于蒸馏损失值,更新学生模型的参数。

在一例中,蒸馏损失函数设置为,差距与蒸馏损失值正相关;且当第二数据提高时,蒸馏损失值降低。

在一例中,蒸馏损失函数为:

其中,为监督数据,为第一数据,为第二数据。

在一例中,训练集还包括:与训练数据一一对应的标准标注数据;方法还包括:基于训练数据,得到学生模型输出的学生输出数据;以及,基于标准标注数据以及学生输出数据,根据任务损失函数得到任务损失值;基于蒸馏损失值,更新学生模型的参数,包括:基于任务损失值与蒸馏损失值的损失总值,更新学生模型的参数。

在一例中,方法还包括:经过多次迭代使得损失总值低于训练阈值后,删除第二输出层,得到完成训练的学生模型。

在一例中,蒸馏位置包括以下一个或多个:学生模型中任一特征提取层,学生模型的全连接层。

本公开的第二方面提供一种图像处理方法,方法包括:获取图像;通过模型提取图像的图像特征,得到图像识别结果,其中,模型为通过第一方面的基于知识蒸馏的模型训练方法得到的学生模型。

本公开的第三方面提供一种基于知识蒸馏的模型训练装置,应用于学生模型,装置包括:模型构建模块,用于根据蒸馏位置,设置与蒸馏位置的第一输出层相同的第二输出层;第一获取模块,用于获取训练集,训练集包括多个训练数据;数据处理模块,用于基于训练数据,得到第一输出层输出的第一数据、以及第二输出层输出的第二数据;第二获取模块,用于获取教师模型基于训练数据,在与蒸馏位置对应的教师层输出的监督数据,其中教师模型为已完成训练、且与学生模型完成相同任务的复杂模型;损失计算模块,用于基于监督数据与第一数据的差距、以及第二数据,根据蒸馏损失函数得到蒸馏损失值;参数调整模块,用于基于蒸馏损失值,更新学生模型的参数。

本公开的第四方面提供一种图像处理装置,装置包括:图像获取模块,用于获取图像;图像识别模块,用于通过模型提取图像的图像特征,得到图像识别结果,其中,模型为通过如第一方面的基于知识蒸馏的模型训练方法得到的学生模型。

本公开的第五方面提供一种电子设备,包括:存储器,用于存储指令;以及处理器,用于调用存储器存储的指令执行第一方面的基于知识蒸馏的模型训练方法或第二方面的图像处理方法。

本公开的第六方面提供一种计算机可读存储介质,其中存储有指令,指令被处理器执行时,执行如第一方面的基于知识蒸馏的模型训练方法或第二方面的图像处理方法。

本公开提供的基于知识蒸馏的模型训练方法及装置通过根据蒸馏位置增加输出层并通过相应的损失函数,使得知识蒸馏中教师模型更加偏重将简单数据的知识传递给学生模型,即自适应知识迁移,减少传递脏数据、过难样本数据的知识传递给学生模型,能够适应任何学生模型并且可以根据需要对不同位置进行知识迁移,保证了在模型结构简单、参数少的学生模型的训练效果,以及学生模型识别结果的准确性和可靠性。

附图说明

通过参考附图阅读下文的详细描述,本公开实施方式的上述以及其他目的、特征和优点将变得易于理解。在附图中,以示例性而非限制性的方式示出了本公开的若干实施方式,其中:

图1示出了根据本公开一实施例基于知识蒸馏的模型训练方法的流程示意图;

图2示出了根据本公开另一实施例基于知识蒸馏的模型训练方法的流程示意图;

图3示出了根据本公开一实施例图像处理方法的流程示意图;

图4示出了根据本公开一实施例基于知识蒸馏的模型训练装置的示意图;

图5示出了根据本公开一实施例图像处理装置的示意图;

图6是本公开实施例提供的一种电子设备示意图。

在附图中,相同或对应的标号表示相同或对应的部分。

具体实施方式

下面将参考若干示例性实施方式来描述本公开的原理和精神。应当理解,给出这些实施方式仅仅是为了使本领域技术人员能够更好地理解进而实现本公开,而并非以任何方式限制本公开的范围。

需要注意,虽然本文中使用“第一”、“第二”等表述来描述本公开的实施方式的不同模块、步骤和数据等,但是“第一”、“第二”等表述仅是为了在不同的模块、步骤和数据等之间进行区分,而并不表示特定的顺序或者重要程度。实际上,“第一”、“第二”等表述完全可以互换使用。

为了使得通过知识蒸馏训练学生网络的过程更加高效,将更可靠更有用的知识从复杂的教师模型传递给模型更简化的学生模型,本公开实施例提供了一种基于知识蒸馏的模型训练方法10,应用于教师-学生模型的知识蒸馏框架中的学生模型,如图1所示,基于知识蒸馏的模型训练方法10可以包括步骤s11-步骤s16,下面对上述步骤进行详细说明:

步骤s11,根据蒸馏位置,设置与蒸馏位置的第一输出层相同的第二输出层。

根据需要进行知识蒸馏的位置,对学生模型的进行简单的改造,根据该位置原有的第一输出层,增加与第一输出层并列设置的第二输出层。第二输出层与第一输出层的位置、结构均相同,其参数可以不同,并在之后的训练过程中,两者参数独立调整。

在一些实施例中,蒸馏位置包括以下一个或多个:学生模型中任一特征提取层,学生模型的全连接层。可以根据实际需要,选择蒸馏的位置,并且,在一次训练过程中,可以选择多个蒸馏位置,对每个蒸馏位置均进行前述结构改造即可。

步骤s12,获取训练集,训练集包括多个训练数据。

获取用于训练的多个训练数据,在多次迭代训练中,训练数据用于输入模型,并通过监督数据对结果计算损失,从而更新模型参数。

步骤s13,基于训练数据,得到第一输出层输出的第一数据、以及第二输出层输出的第二数据。

将训练数据输入学生模型后,通过向后传播,通过第一输出层可以得到第一数据,在同样位置,与第一输出层并列的第二输出层得到第二数据。根据蒸馏位置不同,即第一输出层和第二输出层的不同,第一数据、第二数据可以是特征数据,特征图,神经网络softmax前输出的logit等。其中第一数据和第二数据的值可能不同,但格式和维度相同,如都是特征向量表达。

步骤s14,获取教师模型基于训练数据,在与蒸馏位置对应的教师层输出的监督数据,其中教师模型为已完成训练、且与学生模型完成相同任务的复杂模型。

教师模型是更复杂的模型,可以提现在其层数更多,结构更复杂,参数更多等,因此教师模型也具有非常好的性能和泛化能力,同时也由于需要更大的存储空间以及更多计算力的支持,难于部署在一些终端设备。

本实施例中的教师模型是已经完成训练的模型,并且与学生模型用于完成相同的任务,如均用于图像识别。而且两者结构基本相同,基于学生模型的蒸馏位置,教师模型中也有相应的教师层输出相应的监督数据,该监督数据与第一数据和第二数据具有相同的格式和维度。

步骤s15,基于监督数据与第一数据的差距、以及第二数据,根据蒸馏损失函数得到蒸馏损失值。

在一些传统技术中,仅根据监督数据和第一数据的差距,对学生模型的参数进行调整。而本公开实施例中,增加了第二数据,蒸馏损失函数包括监督数据与第一数据的差距,以及第二数据,以得到蒸馏损失值,学生模型根据蒸馏损失值能够更新模型的参数。

在一实施例中,蒸馏损失函数设置为差距与蒸馏损失值正相关;且当第二数据提高时,蒸馏损失值降低。监督数据与第一数据的差距越大,蒸馏损失值越大,则学生模型需要调整参数的幅度也越大,说明对于该训练数据学生模型蒸馏位置的输出的数据不能很好的表达其特征。但另一方面,监督数据与第一数据的差距过大,也说明该训练数据可能是脏数据,或者是过难的训练数据,对于脏数据应尽量避免知识迁移,而过难的训练数据在实际应用环境中可能是不常见的数据,对于简化的学生模型来说,学习过难的训练数据意义不大。因此,蒸馏损失函数中,通过第二数据,在监督数据与第一数据的差距过大的情况下,可以调整第二输出层的参数,提高第二数据,降低蒸馏损失函数的蒸馏损失值。训练的过程即为使损失值降低的过程,因此,学生模型在根据蒸馏损失值进行参数更新时,能够通过更新第二输出层的参数使第二数据增大,从而达到降低蒸馏损失值的目的,通过该种方式,即相应的降低了调整其他参数权重、尤其相对降低了调整第一输出层参数的权重。通过上述方式,减少了训练时脏数据或过难数据的知识迁移的危害,相应也就提高了干净数据、较难数据的训练效果,使得训练更加高效。

在另一些实施例中,蒸馏损失函数ldistill可以是

其中,为监督数据,为第一数据,为第二数据,d代表维度,n代表训练数据的批次数量。监督数据、第一数据和第二数据是相应蒸馏位置的数据,数据形式一样,如可以是输出的d维度的特征向量,而其中监督数据为教师模型输出的用于传递知识的数据,第一数据是学生模型中蒸馏位置的第一输出层输出的能够表示出特征的数据,而第二数据则是在蒸馏位置设置的第二输出层输出的能够表示该训练数据的置信度的数据,起到了调节权重的作用,在训练数据为脏数据等情况下,能够降低对模型其他参数的影响。

在损失值超过阈值的情况下,需要更新学生模型的参数以降低损失值,学生模型可以通过更新参数,使得第一数据靠近监督数据,从而降低蒸馏损失值,在以往技术中,仅能通过这种方式降低损失值,无法避免或降低脏数据或过难数据对学生模型的影响。而本实施例中,根据公式中部分可以看出,除了可以通过更新学生模型参数使得第一数据靠近监督数据降低损失值之外,在监督数据与第一数据的差距过大时,学生模型还可以通过更新参数,尤其是第二输出层的参数,以提高第二数据也能够降低蒸馏损失值,在这种情况下,减少对学生模型其他参数的调整,也就减少了脏数据或过难数据对学生模型的不良影响。同时,为避免学生模型过于偏重调整第二输出层的参数,仅仅一味提高第二数据的值以降低蒸馏损失值,导致无法从教师模型中迁移更多的知识,上述公式中同时还设置一制约因子在第二数据过低时,制约因子的值会增加,因此避免了学生模型仅盲目调整第二输出层参数以降低第二数据的情况。

步骤s16,基于蒸馏损失值,更新学生模型的参数。

更新模型的参数,包括学生模型的全部参数,其中包括通过调整第一输出层的参数和调整第二输出层的参数,而调整第一输出层的参数对第一数据的影响最强,调整第二输出层的参数对第一数据影响最强。因此,根据蒸馏损失值能够更好的训练更新蒸馏位置的参数,并且通过前述任一实施例蒸馏损失函数能够减少脏数据或过难数据对学生模型造成的不良影响,使得训练更加高效,训练完毕的学生模型的输出结果更加准确。

在一实施例中,如图2所示,基于知识蒸馏的模型训练方法10,其中,训练集还包括:与训练数据一一对应的标准标注数据;基于知识蒸馏的模型训练方法10还包括:步骤s17,基于训练数据,得到学生模型输出的学生输出数据;以及,步骤s18,基于标准标注数据以及学生输出数据,根据任务损失函数得到任务损失值;同时,步骤s16包括:基于任务损失值与蒸馏损失值的损失总值,更新学生模型的参数。

为完整的更好的训练学生模型,对学生模型进行任务相关的训练,即根据模型实际的要完成的任务,通过步骤s17将训练数据输入学生模型,并通过全部学生模型中的结构输出最终的学生输出数据,即识别结果、聚类结果等,再通过步骤s18,基于学生输出数据和训练数据的标准标注数据的比对,根据任务损失函数得到任务损失值;之后在步骤s16中,根据任务损失值与蒸馏损失值的损失总值,更新学生模型的参数。从而对学生模型进行全面准确的训练,避免局部蒸馏对模型整体结果的影响。

在一实施例中,基于知识蒸馏的模型训练方法10,还包括经过多次迭代使得损失总值低于训练阈值后,删除第二输出层,得到完成训练的学生模型。当学生模型的输出足够收敛,即损失总值低于一个训练阈值后,可以认为学生模型的参数更新达到要求,此时,可以将事先设置的第二输出层删除,得到完成训练的学生模型。本公开提供的实施例中,第二输出层仅用于在基于知识蒸馏的模型训练过程中,提供第二数据,达到前述实施例描述的作用和效果,而在参数更新达到要求后,第二输出层的存在意义也消失了,将其删除能够恢复恢复学生模型的原有结构,降低学生模型的存储空间,减少无意义的计算。在同时蒸馏多个蒸馏位置时,尤为显著。通过本实例的方式能够恢复学生模型的原有结构,也提高了本公开实施例的泛用性。

基于同一发明构思,图3示出了本公开实施例提供的一种图像处理方法20,包括:步骤s21,获取图像;步骤s22,通过模型提取图像的图像特征,得到图像识别结果,其中,模型为通过前述任一实施例的基于知识蒸馏的模型训练方法10得到的学生模型。在一些场景中,学生模型用于图像识别,基于知识蒸馏的模型训练方法10得到的学生模型训练更加高效,结构精简,运算速度更快,可以在终端设备中使用,并且也保证了图像处理结果的准确性。

基于同一发明构思,图4示出了本公开实施例提供的一种基于知识蒸馏的模型训练装置100,应用于学生模型,如图4所示,基于知识蒸馏的模型训练装置100包括:模型构建模块110,用于根据蒸馏位置,设置与蒸馏位置的第一输出层相同的第二输出层;第一获取模块120,用于获取训练集,训练集包括多个训练数据;数据处理模块130,用于基于训练数据,得到第一输出层输出的第一数据、以及第二输出层输出的第二数据;第二获取模块140,用于获取教师模型基于训练数据,在与蒸馏位置对应的教师层输出的监督数据,其中教师模型为已完成训练、且与学生模型完成相同任务的复杂模型;损失计算模块150,用于基于监督数据与第一数据的差距、以及第二数据,根据蒸馏损失函数得到蒸馏损失值;参数调整模块160,用于基于蒸馏损失值,更新学生模型的参数。

在一例中,蒸馏损失函数设置为,差距与蒸馏损失值正相关;且当第二数据提高时,蒸馏损失值降低。

在一例中,蒸馏损失函数为:

其中,为监督数据,为第一数据,为第二数据。

在一例中,训练集还包括:与训练数据一一对应的标准标注数据;数据处理模块130还用于:基于训练数据,得到学生模型输出的学生输出数据;损失计算模块150还用于:基于标准标注数据以及学生输出数据,根据任务损失函数得到任务损失值;参数调整模块160还用于:基于任务损失值与蒸馏损失值的损失总值,更新学生模型的参数。

在一例中,模型构建模块110还用于:经过多次迭代使得损失总值低于训练阈值后,删除第二输出层,得到完成训练的学生模型。

在一例中,蒸馏位置包括以下一个或多个:学生模型中任一特征提取层,学生模型的全连接层。

关于上述实施例中的基于知识蒸馏的模型训练装置100,其中各个模块执行操作的具体方式已经在有关该方法的实施例中进行了详细描述,此处将不做详细阐述说明。

基于同一发明构思,图5示出了本公开实施例提供的一种图像处理装置200,如图5所示,图像处理装置200包括:图像获取模块210,用于获取图像;图像识别模块220,用于通过模型提取图像的图像特征,得到图像识别结果,其中,模型为通过如第一方面的基于知识蒸馏的模型训练方法10得到的学生模型。

关于上述实施例中的图像处理装置200,其中各个模块执行操作的具体方式已经在有关该方法的实施例中进行了详细描述,此处将不做详细阐述说明。

如图6所示,本公开的一个实施方式提供了一种电子设备300。其中,该电子设备300包括存储器301、处理器302、输入/输出(input/output,i/o)接口303。其中,存储器301,用于存储指令。处理器302,用于调用存储器301存储的指令执行本公开实施例的神经网络压缩方法或图像处理方法。其中,处理器302分别与存储器301、i/o接口303连接,例如可通过总线系统和/或其他形式的连接机构(未示出)进行连接。存储器301可用于存储程序和数据,包括本公开实施例中涉及的神经网络压缩方法或图像处理方法的程序,处理器302通过运行存储在存储器301的程序从而执行电子设备300的各种功能应用以及数据处理。

本公开实施例中处理器302可以采用数字信号处理器(digitalsignalprocessing,dsp)、现场可编程门阵列(field-programmablegatearray,fpga)、可编程逻辑阵列(programmablelogicarray,pla)中的至少一种硬件形式来实现,所述处理器302可以是中央处理单元(centralprocessingunit,cpu)或者具有数据处理能力和/或指令执行能力的其他形式的处理单元中的一种或几种的组合。

本公开实施例中的存储器301可以包括一个或多个计算机程序产品,所述计算机程序产品可以包括各种形式的计算机可读存储介质,例如易失性存储器和/或非易失性存储器。所述易失性存储器例如可以包括随机存取存储器(randomaccessmemory,ram)和/或高速缓冲存储器(cache)等。所述非易失性存储器例如可以包括只读存储器(read-onlymemory,rom)、快闪存储器(flashmemory)、硬盘(harddiskdrive,hdd)或固态硬盘(solid-statedrive,ssd)等。

本公开实施例中,i/o接口303可用于接收输入的指令(例如数字或字符信息,以及产生与电子设备300的用户设置以及功能控制有关的键信号输入等),也可向外部输出各种信息(例如,图像或声音等)。本公开实施例中i/o接口303可包括物理键盘、功能按键(比如音量控制按键、开关按键等)、鼠标、操作杆、轨迹球、麦克风、扬声器、和触控面板等中的一个或多个。

可以理解的是,本公开实施例中尽管在附图中以特定的顺序描述操作,但是不应将其理解为要求按照所示的特定顺序或是串行顺序来执行这些操作,或是要求执行全部所示的操作以得到期望的结果。在特定环境中,多任务和并行处理可能是有利的。

本公开实施例涉及的方法和装置能够利用标准编程技术来完成,利用基于规则的逻辑或者其他逻辑来实现各种方法步骤。还应当注意的是,此处以及权利要求书中使用的词语“装置”和“模块”意在包括使用一行或者多行软件代码的实现和/或硬件实现和/或用于接收输入的设备。

此处描述的任何步骤、操作或程序可以使用单独的或与其他设备组合的一个或多个硬件或软件模块来执行或实现。在一个实施方式中,软件模块使用包括包含计算机程序代码的计算机可读介质的计算机程序产品实现,其能够由计算机处理器执行用于执行任何或全部的所描述的步骤、操作或程序。

出于示例和描述的目的,已经给出了本公开实施的前述说明。前述说明并非是穷举性的也并非要将本公开限制到所公开的确切形式,根据上述教导还可能存在各种变形和修改,或者是可能从本公开的实践中得到各种变形和修改。选择和描述这些实施例是为了说明本公开的原理及其实际应用,以使得本领域的技术人员能够以适合于构思的特定用途来以各种实施方式和各种修改而利用本公开。


技术特征:

1.一种基于知识蒸馏的模型训练方法,其中,应用于学生模型,所述方法包括:

根据蒸馏位置,设置与所述蒸馏位置的第一输出层相同的第二输出层;

获取训练集,所述训练集包括多个训练数据;

基于所述训练数据,得到所述第一输出层输出的第一数据、以及所述第二输出层输出的第二数据;

获取教师模型基于所述训练数据,在与所述蒸馏位置对应的教师层输出的监督数据,其中所述教师模型为已完成训练、且与所述学生模型完成相同任务的复杂模型;

基于所述监督数据与所述第一数据的差距、以及所述第二数据,根据蒸馏损失函数得到蒸馏损失值;

基于所述蒸馏损失值,更新所述学生模型的参数。

2.根据权利要求1所述的基于知识蒸馏的模型训练方法,其中,所述蒸馏损失函数设置为所述差距与所述蒸馏损失值正相关;且当所述第二数据提高时,所述蒸馏损失值降低。

3.根据权利要求2所述的基于知识蒸馏的模型训练方法,其中,所述蒸馏损失函数为:

其中,为所述监督数据,为所述第一数据,为所述第二数据。

4.根据权利要求1-3任一项所述的基于知识蒸馏的模型训练方法,其中,

所述训练集还包括:与所述训练数据一一对应的标准标注数据;

所述方法还包括:基于所述训练数据,得到所述学生模型输出的学生输出数据;以及,基于所述标准标注数据以及所述学生输出数据,根据任务损失函数得到任务损失值;

所述基于所述蒸馏损失,更新所述学生模型的参数,包括:基于所述任务损失值与所述蒸馏损失值的损失总值,更新所述学生模型的参数。

5.根据权利要求4所述的基于知识蒸馏的模型训练方法,其中,所述方法还包括:经过多次迭代使得所述损失总值低于训练阈值后,删除所述第二输出层,得到完成训练的学生模型。

6.根据权利要求1所述的基于知识蒸馏的模型训练方法,其中,所述蒸馏位置包括以下一个或多个:所述学生模型中任一特征提取层,所述学生模型的全连接层。

7.一种图像处理方法,其中,所述方法包括:

获取图像;

通过模型提取所述图像的图像特征,得到图像识别结果,其中,所述模型为通过如权利要求1-6任一项所述的基于知识蒸馏的模型训练方法得到的学生模型。

8.一种基于知识蒸馏的模型训练装置,其中,应用于学生模型,所述装置包括:

模型构建模块,用于根据蒸馏位置,设置与所述蒸馏位置的第一输出层相同的第二输出层;

第一获取模块,用于获取训练集,所述训练集包括多个训练数据;

数据处理模块,用于基于所述训练数据,得到所述第一输出层输出的第一数据、以及所述第二输出层输出的第二数据;

第二获取模块,用于获取教师模型基于所述训练数据,在与所述蒸馏位置对应的教师层输出的监督数据,其中所述教师模型为已完成训练、且与所述学生模型完成相同任务的复杂模型;

损失计算模块,用于基于所述监督数据与所述第一数据的差距、以及所述第二数据,根据蒸馏损失函数得到蒸馏损失值;

参数调整模块,用于基于所述蒸馏损失值,更新所述学生模型的参数。

9.一种图像处理装置,其中,所述装置包括:

图像获取模块,用于获取图像;

图像识别模块,用于通过模型提取所述图像的图像特征,得到图像识别结果,其中,所述模型为通过如权利要求1-6任一项所述的基于知识蒸馏的模型训练方法得到的学生模型。

10.一种电子设备,其中,所述电子设备包括:

存储器,用于存储指令;以及

处理器,用于调用所述存储器存储的指令执行如权利要求1-6中任一项所述的基于知识蒸馏的模型训练方法或如权利要求7所述的图像处理方法。

11.一种计算机可读存储介质,其中存储有指令,所述指令被处理器执行时,执行如权利要求1-6中任一项所述的基于知识蒸馏的模型训练方法或如权利要求7所述的图像处理方法。

技术总结
本公开提供了一种基于知识蒸馏的模型训练方法,应用于学生模型,包括:根据蒸馏位置,设置与蒸馏位置的第一输出层相同的第二输出层;获取训练集,训练集包括多个训练数据;基于训练数据,得到第一输出层输出的第一数据、以及第二输出层输出的第二数据;获取教师模型基于训练数据,在与蒸馏位置对应的教师层输出的监督数据,其中教师模型为已完成训练、且与学生模型完成相同任务的复杂模型;基于监督数据与第一数据的差距、以及第二数据,根据蒸馏损失函数得到蒸馏损失值;基于蒸馏损失值,更新学生模型的参数。通过公开实施例使得知识蒸馏中教师模型更加偏重将简单数据的知识传递给学生模型,提高了知识蒸馏的训练效率,保证了学生模型准确性。

技术研发人员:张有才;戴雨辰;常杰;危夷晨
受保护的技术使用者:北京迈格威科技有限公司
技术研发日:2019.12.19
技术公布日:2020.06.05

转载请注明原文地址: https://bbs.8miu.com/read-52582.html

最新回复(0)