本发明一般地涉及图像识别技术领域,特别是涉及一种神经网络训练方法、图像处理方法及装置,以及电子设备和计算机可读存储介质。
背景技术:
目前神经网络训练过程中,通过大量训练样本的多次迭代训练,进行度量学习是为了学习一个在图片上的距离函数,使得同一类别的图片将被映射到相邻的位置,具有不同类别的图片将被映射到互相远离的地方。在深度学习的方法中,这个距离函数即是网络的特征嵌入,现在的有很多基于深度的度量学习损失,例如对比损失(contrastiveloss)、三元组损失(tripletloss)、四元组损失(quadrupletloss)等。这些损失都在一些相关联的样本上计算,都有相同的目标,即鼓励相同类的样本相互接近,不同类的样本相互远离。
对于一些相关度量学习技术来说,由于资源的限制,存在着只能在局部样本上进行优化的问题。具体来说,目前深度网络通常是使用随机梯度下降算法优化的,因此梯度只来源于一小批样本。这样对于相关的度量学习损失来说,只看到了局部的数据分布,无法优化到全局的最优。如果无限制的对样本进行采样,不仅引入了大的计算量,还将同时引入大量的无信息样本对,影响模型的收敛速度和性能。
技术实现要素:
为了解决现有技术中存在的上述问题,本发明提供一种神经网络训练方法、图像处理方法及装置,以及电子设备和计算机可读存储介质。
根据本公开实施例的第一方面,提供一种神经网络训练方法,包括:获取总训练集,总训练集包括多个类别的训练数据,其中每个类别包括一个或多个训练数据;针对每一训练轮次,基于总训练集,对每个类别的训练数据进行采样,得到采样后的训练数据组成的子训练集;根据神经网络对子训练集中的每个类别的训练数据进行特征提取,得到特征向量;基于当前神经网络,确定每个类别的中心向量,其中中心向量作为锚点;基于中心向量和特征向量获得损失函数的值,根据损失函数的值调整神经网络的参数。
在一例中,基于当前神经网络,确定每个类别的中心向量,包括:基于当前训练轮次的子训练集中每个类别的训练数据的特征向量,分别对特征向量对应类别的历史中心向量进行更新,得到每个类别的当前训练轮次的中心向量,其中历史中心向量为前一训练轮次的中心向量,对于第一训练轮次的前一训练轮次的中心向量为预设的每个类别的初始中心向量。
在一例中,基于当前训练轮次的子训练集中每个类别的训练数据的特征向量,分别对特征向量对应类别的历史中心向量进行更新,采用以下方式:分别设置子训练集中每个训练数据的特征向量的第一权重,以及历史中心向量的第二权重;基于第一权重和第二权重,对当前训练轮次每个类别的训练数据的特征向量与对应类别的历史中心向量进行加权,获得对应类别当前训练轮次的中心向量。
在一例中,基于当前训练轮次的子训练集中每个类别的训练数据的特征向量,分别对特征向量对应类别的历史中心向量进行更新,采用以下方式:基于每个类别的训练数据的特征向量,确定与每个类别对应的中心损失;基于中心损失,更新对应类别的历史中心向量,得到对应类别的当前训练轮次中心向量。
在一例中,基于中心损失,更新对应类别的历史中心向量,得到对应类别的当前训练轮次中心向量,包括:基于中心损失,获取对应类别的历史中心向量与当前训练轮次的子训练集中对应类别的训练数据的特征向量之间的距离;根据距离,确定对应类别的历史中心向量的更新量;根据更新量更新对应类别的历史中心向量,得到对应类别的当前训练轮次中心向量。
在一例中,基于当前神经网络,确定每个类别的中心向量,包括:基于神经网络的每个类别的分类输出权重,确定对应类别的中心向量。
在一例中,基于中心向量和特征向量获得损失函数的值,根据损失函数的值调整神经网络的参数,包括:基于中心向量,在当前训练轮次的子训练集全部训练数据的特征向量中确定第一向量和第二向量,其中,第一向量为与中心向量类别相同的特征向量,第二向量为与中心向量类别不同的特征向量;基于中心向量、第一向量以及第二向量,通过三元组损失函数得到损失值,并基于损失值调整神经网络的参数。
在一例中,第一向量为与中心向量之间距离最远且与中心向量类别相同的特征向量;第二向量为与中心向量之间距离最近且与中心向量类别不同的特征向量。
根据本公开实施例的第二方面,提供一种图像处理方法,包括:获取图像;通过神经网络进行图像识别,得到图像的分类结果,其中神经网络通过第一方面的神经网络训练方法训练得到。
根据本公开实施例的第三方面,提供一种神经网络训练装置,包括:获取模块,用于获取总训练集,总训练集包括多个类别的训练数据,其中每个类别包括一个或多个训练数据;采样模块,用于针对每一训练轮次,基于总训练集,对每个类别的训练数据进行采样,得到采样后的训练数据组成的子训练集;特征提取模块,用于根据神经网络对子训练集中的每个类别的训练数据进行特征提取,得到特征向量;中心确定模块,用于确定每个类别的当前轮次的中心向量,其中中心向量作为锚点;训练模块,用于基于中心向量和特征向量获得损失函数的值,根据损失函数的值调整神经网络的参数。
根据本公开实施例的第四方面,提供一种图像处理装置,包括:接收模块,用于获取图像;处理模块,用于通过神经网络进行图像识别,得到图像的分类结果,其中神经网络通过第一方面的神经网络训练方法训练得到。
根据本公开实施例的第五方面,提供一种电子设备,其中,电子设备包括:存储器,用于存储指令;以及处理器,用于调用存储器存储的指令执行第一方面的神经网络训练方法或第二方面图像处理方法。
根据本公开实施例的第六方面,提供一种计算机可读存储介质,其中,计算机可读存储介质存储有计算机可执行指令,计算机可执行指令在由处理器执行时,执行第一方面的神经网络训练方法或第二方面图像处理方法。
本公开提供的神经网络训练方法、图像处理方法及装置,以及电子设备和计算机可读存储介质,通过在训练过程中,根据当前的神经网络确定中心向量,使得中心向量能够具有全局信息,从而基于中心向量计算损失并调整神经网络的参数,能够提高训练效率,加快神经网络输出结果的收敛速度,并且能够降低获取训练样本的难度,提高了训练效果。
附图说明
通过参考附图阅读下文的详细描述,本发明实施方式的上述以及其他目的、特征和优点将变得易于理解。在附图中,以示例性而非限制性的方式示出了本发明的若干实施方式,其中:
图1示出了本发明实施例提供的一种神经网络训练方法的流程示意图;
图2示出了本发明实施例提供的一种神经网络训练方法中确定每个类别的中心向量步骤的流程示意图;
图3示出了本发明实施例提供的一种神经网络训练方法中调整神经网络的参数步骤的流程示意图;
图4示出了本发明实施例提供的一种图像处理方法示意图;
图5示出了本发明实施例提供的一种神经网络训练装置示意图;
图6示出了本发明实施例提供的一种图像处理装置示意图;
图7示出了本发明实施例提供的一种电子设备示意图;
在附图中,相同或对应的标号表示相同或对应的部分。
具体实施方式
下面将参考若干示例性实施方式来描述本发明的原理和精神。应当理解,给出这些实施方式仅仅是为了使本领域技术人员能够更好地理解进而实现本发明,而并非以任何方式限制本发明的范围。
需要注意,虽然本文中使用“第一”、“第二”等表述来描述本发明的实施方式的不同模块、步骤和数据等,但是“第一”、“第二”等表述仅是为了在不同的模块、步骤和数据等之间进行区分,而并不表示特定的顺序或者重要程度。实际上,“第一”、“第二”等表述完全可以互换使用。
目前,神经网络应用于各个领域,尤其是分类领域中,通过神经网络能够对目标进行快速的分类识别。对神经网络的训练效率、训练效果的要求也越来越高,需要神经网络的输出结果更快的收敛,并且也对训练样本的质量提出了更高的要求,使得训练样本的获取难度加大。
为解决上述问题,本公开提供的一种神经网络训练方法10,其中神经网络可以是用于分类的神经网络,在一些实施例中,神经网络可以是用于图像识别的神经网络,如卷积神经网络。图1为本公开实施例示出的一种神经网络训练方法10的示意图,如图1所示,该方法包括步骤s11至步骤s15:
步骤s11,获取总训练集,总训练集包括多个类别的训练数据,其中每个类别包括一个或多个训练数据。
获取用于训练的总训练集,包括全部用于训练的训练数据,根据需要进行分类的类别,训练数据也具有相应的类别,每个类别可以包括一个或多个训练数据。
步骤s12,针对每一训练轮次,基于总训练集,对每个类别的训练数据进行采样,得到采样后的训练数据组成的子训练集。
在获取总训练集后,针对每一训练轮次,对每个类别的训练数据分别进行采样,形成当前训练轮次的子训练集,其中采样数量可以是一个或多个。在一些实施例中,可以采用pk采样的方式得到子训练集,及从总训练集中,选出p个类别,然后每个类别随机选择k个训练数据,因此每轮训练轮次的训练数据就是p*k个训练数据组成子训练集。
步骤s13,根据神经网络对子训练集中的全部训练数据进行特征提取,得到特征向量。
将子训练集中的全部训练数据输入神经网络,通过神经网络进行特征提取,得到子训练集中全部训练数据的特征向量,特征向量对应训练数据的类别。
步骤s14,基于当前神经网络,确定每个类别的中心向量,其中中心向量作为锚点。
在计算损失前,对于每个类别,都确定一个与特征向量具有相同维度的中心向量,并且该中心向量具有训练迭代中当前神经网络的全局信息,将每个类别的中心向量作为三元组损失计算中该类别的锚点。
在一实施例中,如图2所示,步骤s14可以包括:步骤s141,基于当前训练轮次的子训练集中每个类别的训练数据的特征向量,分别对特征向量对应类别的历史中心向量进行更新,得到每个类别的当前训练轮次中心向量,其中历史中心向量为前一训练轮次的中心向量,对于第一训练轮次,前一训练轮次的中心向量为预设的每个类别的初始中心向量。
本公开实施例中,神经网络通过多训练轮次调整参数,使得神经网络达到标准以完成训练。每个轮次中,均需要进行步骤s12-步骤s15,即每个训练轮次中均需要对中心向量进行在线更新,从而将当前全局信息加入到中心向量中,使得中心向量更加稳定可靠。
上述实施例中可以在训练开始时,设置一个初始中心向量,初始中心向量的值可以为零,在训练开始后,第一训练轮次通过当前神经网络输出的特征向量,针对对应类别的初始中心向量进行更新,而后每个训练轮次中,均通过当前神经网络输出的特征向量针对对应类别的前一训练轮次的中心向量进行更新。通过上述在线更新的方式,能够将当前全局信息加入到中心向量中,从而能够基于一个更加准确和稳定的特征作为基础进行计算损失。
在又一实施例中,步骤s141,具体可采用以下方式对中心向量进行在线更新:分别设置子训练集中每个训练数据的特征向量的第一权重,以及历史中心向量的第二权重;基于第一权重和第二权重,对当前训练轮次每个类别的训练数据的特征向量与对应类别的历史中心向量进行加权,获得对应类别当前训练轮次的中心向量。本实施例中,通过加权和的方式更新中心向量,具体用公式1表达如下:
其中,vt为当前轮次中心向量,vt-1为历史中心向量(当t=1时,vt-1为初始中心向量),y为第二权重,1-y为第一权重,xi为特征向量。
在另一实施例中,步骤s141,具体可采用以下方式对中心向量进行在线更新:基于每个类别的训练数据的特征向量,确定与每个类别对应的中心损失;基于中心损失,更新对应类别的历史中心向量,得到对应类别的当前训练轮次中心向量。本实施例中,通过使用中心损失(centerloss)更新中心向量。其中中心损失lc可以通过公式2表达如下:
其中xi为特征向量,
在一具体实施例中,基于中心损失,可以通过求导的方式,获取对应类别的历史中心向量与当前训练轮次的子训练集中对应类别的训练数据的特征向量之间的距离;可以通过公式3获取该距离:
根据距离,确定对应类别的历史中心向量的更新量;具体可以用公式4获取该更新量δcj:
其中δ为条件函数,满足条件则为1,不满足条件则为0,δ(yi=j)保证了类别的对应。
根据更新量更新对应类别的历史中心向量,得到对应类别的当前训练轮次中心向量,即可根据全局信息动态更新中心向量。
在一实施例中,步骤s14也可以包括:基于神经网络的每个类别的分类输出权重,确定对应类别的中心向量。本实施例中,可以将神经网络中用于对特征向量进行分类打分的全连接层(fclayer)的分类输出权重作为中心向量的值,具体而言,全连接层对神经网络提取的特征向量的每一维度设置有分类输出权重,特征向量和各个类别的分类输出权重相乘得到每个类别的相似得分,得分越大越可能属于这一类别。对于cosface等softmaxloss的变体,特征向量和分类输出权重都会做归一化,模长都为1,因此一个类别对应的分类输出权重序列即可作为该类别的中心向量。并且分类输出权重随着神经网络训练的每轮训练轮次进行更新,因此基于分类输出权重确定的当前训练轮次的中心向量,能够基于全局信息在线更新中心向量。
步骤s15,基于中心向量和特征向量获得损失函数的值,根据损失函数的值调整神经网络的参数。
由于中心向量具有全局信息,因此基于中心向量、特征向量通过损失函数得到损失函数的值调整神经网络的参数能够更准确,使得神经网络的输出结果能够更快的收敛,提高了训练效率。
并且,本公开实施例中,在每轮训练轮次更新神经网络参数的过程中,中心向量作为常量,不参与神经网络的参数的更新,从而保证了其稳定性。
在一实施例中,如图3所示,步骤s15可以包括:步骤s151,基于中心向量,在当前训练轮次的子训练集全部训练数据的特征向量中确定第一向量和第二向量,其中,第一向量为与中心向量类别相同的特征向量,第二向量为与中心向量类别不同的特征向量;步骤s152,基于中心向量、第一向量以及第二向量,通过三元组损失函数得到损失值,并基于损失值调整神经网络的参数。本实施例中,采用三元组损失(tripletloss)计算损失,其中,三元组损失函数是通过标准数据,即锚点(anchor)、同类数据(positive)、异类数据(negative)三元,将anchor和positive进行比对,也将anchor和negative进行比对,基于函数计算损失调整参数后使得anchor和positive之间的距离最小,并且使anchor和negative之间距离最大,其中,positive为和anchor属于同一类别的训练数据,而negative则为和anchor不同类别的训练数据。且,在negative与anchor两个类别之间的距离(相似度)更接近时,训练效率会更高。
本公开实施例中,将中心向量作为标准数据,即锚点(anchor),第一向量即为同类数据(positive),第二向量即为异类数据(negative)。可通过公式5表达:
其中ltri为损失值,xa为中心向量,xp为第一向量,xn为第二向量,α为间隔值(margin)
在又一实施例中,第一向量可以为与中心向量之间距离最远且与中心向量类别相同的特征向量,第二向量可以为与中心向量之间距离最近且与中心向量类别不同的特征向量。通过本实施例的方式,可以对训练样本的进行难例挖掘,每次通过难样本进行训练,使得收敛速度更快,并且能够提高对训练样本利用效果,一定程度上降低了训练样本获取的质量要求。
基于同样的发明构思,本公开还提供一种图像处理方法20,如图4所示,图像处理方法20包括:
步骤s21,获取图像。用于获取待测图像。
步骤s22,通过神经网络进行图像识别,得到图像的分类结果,其中神经网络通过前述任一实施例的神经网络训练方法10训练得到。
通过神经网络训练方法10得到的神经网络,能够更加准确的进行图像识别。
基于相同的发明构思,本公开实施例提供一种神经网络训练装置100,如图5所示,神经网络训练装置100包括:获取模块110,用于获取总训练集,总训练集包括多个类别的训练数据,其中每个类别包括一个或多个训练数据;采样模块120,用于针对每一训练轮次,基于总训练集,对每个类别采样的训练数据进行采样,得到采样后的训练数据组成的子训练集;特征提取模块130,用于根据神经网络对子训练集中的每个类别的训练数据进行特征提取,得到特征向量;中心确定模块140,用于确定每个类别的当前轮次的中心向量,其中中心向量作为锚点;训练模块150,用于基于中心向量和特征向量获得损失函数的值,根据损失函数的值调整神经网络的参数。
在一实施例中,中心确定模块140用于:基于当前训练轮次的子训练集中每个类别的特征向量,分别对特征向量对应类别的历史中心向量进行更新,得到每个类别的当前训练轮次的中心向量,其中历史中心向量为前一轮次的中心向量,对于第一训练轮次,前一训练轮次的中心向量为预设的每个类别的初始中心向量。
在一实施例中,中心确定模块140采用以下方式对特征向量对应类别的历史中心向量进行更新:分别设置子训练集中每个训练数据的特征向量的第一权重,以及历史中心向量的第二权重;基于第一权重和第二权重,对当前训练轮次每个类别的训练数据的特征向量与对应类别的历史中心向量进行加权,获得对应类别当前训练轮次的中心向量。
在一实施例中,中心确定模块140采用以下方式对特征向量对应类别的历史中心向量进行更新:基于每个类别的训练数据的特征向量,确定与每个类别对应的中心损失;基于中心损失,更新对应类别的历史中心向量,得到对应类别的当前训练轮次中心向量。
在一实施例中,中心确定模块140还用于:基于中心损失,获取对应类别的历史中心向量与当前训练轮次的子训练集中对应类别的训练数据的特征向量之间的距离;根据距离,确定对应类别的历史中心向量的更新量;根据更新量更新对应类别的历史中心向量,得到对应类别的当前训练轮次中心向量。
在一实施例中,中心确定模块140用于:基于神经网络的每个类别的分类输出权重,确定对应类别的中心向量。
在一实施例中,训练模块150用于:基于中心向量,在当前训练轮次的子训练集全部训练数据的特征向量中确定第一向量和第二向量,其中,第一向量为与中心向量类别相同的特征向量,第二向量为与中心向量类别不同的特征向量;基于中心向量、第一向量以及第二向量,通过三元组损失函数得到损失值,并基于损失值调整神经网络的参数。
在一实施例中,第一向量为与中心向量之间距离最远且与中心向量类别相同的特征向量;第二向量为与中心向量之间距离最近且与中心向量类别不同的特征向量。
关于上述实施例中的神经网络训练装置100,其中各个模块执行操作的具体方式已经在有关该方法的实施例中进行了详细描述,此处将不做详细阐述说明。
基于相同的发明构思,本公开实施例提供一种图像处理装置200,如图6所示,图像处理装置200包括:接收模块210,用于获取图像;处理模块220,用于通过神经网络进行图像识别,得到图像的分类结果,其中神经网络通过前述任一实施例的神经网络训练方法10训练得到。
关于上述实施例中的图像处理装置200,其中各个模块执行操作的具体方式已经在有关该方法的实施例中进行了详细描述,此处将不做详细阐述说明。
如图7所示,本发明的一个实施方式提供了一种电子设备40。其中,该电子设备40包括存储器410、处理器420、输入/输出(input/output,i/o)接口430。其中,存储器410,用于存储指令。处理器420,用于调用存储器410存储的指令执行本发明实施例的用于神经网络训练方法或图像处理方法。其中,处理器420分别与存储器410、i/o接口430连接,例如可通过总线系统和/或其他形式的连接机构(未示出)进行连接。存储器410可用于存储程序和数据,包括本发明实施例中涉及的用于神经网络训练方法或图像处理方法的程序,处理器420通过运行存储在存储器410的程序从而执行电子设备40的各种功能应用以及数据处理。
本发明实施例中处理器420可以采用数字信号处理器(digitalsignalprocessing,dsp)、现场可编程门阵列(field-programmablegatearray,fpga)、可编程逻辑阵列(programmablelogicarray,pla)中的至少一种硬件形式来实现,所述处理器420可以是中央处理单元(centralprocessingunit,cpu)或者具有数据处理能力和/或指令执行能力的其他形式的处理单元中的一种或几种的组合。
本发明实施例中的存储器410可以包括一个或多个计算机程序产品,所述计算机程序产品可以包括各种形式的计算机可读存储介质,例如易失性存储器和/或非易失性存储器。所述易失性存储器例如可以包括随机存取存储器(randomaccessmemory,ram)和/或高速缓冲存储器(cache)等。所述非易失性存储器例如可以包括只读存储器(read-onlymemory,rom)、快闪存储器(flashmemory)、硬盘(harddiskdrive,hdd)或固态硬盘(solid-statedrive,ssd)等。
本发明实施例中,i/o接口430可用于接收输入的指令(例如数字或字符信息,以及产生与电子设备40的用户设置以及功能控制有关的键信号输入等),也可向外部输出各种信息(例如,图像或声音等)。本发明实施例中i/o接口430可包括物理键盘、功能按键(比如音量控制按键、开关按键等)、鼠标、操作杆、轨迹球、麦克风、扬声器、和触控面板等中的一个或多个。
在一些实施方式中,本发明提供了一种计算机可读存储介质,该计算机可读存储介质存储有计算机可执行指令,计算机可执行指令在由处理器执行时,执行上文所述的任何方法。
尽管在附图中以特定的顺序描述操作,但是不应将其理解为要求按照所示的特定顺序或是串行顺序来执行这些操作,或是要求执行全部所示的操作以得到期望的结果。在特定环境中,多任务和并行处理可能是有利的。
本发明的方法和装置能够利用标准编程技术来完成,利用基于规则的逻辑或者其他逻辑来实现各种方法步骤。还应当注意的是,此处以及权利要求书中使用的词语“装置”和“模块”意在包括使用一行或者多行软件代码的实现和/或硬件实现和/或用于接收输入的设备。
此处描述的任何步骤、操作或程序可以使用单独的或与其他设备组合的一个或多个硬件或软件模块来执行或实现。在一个实施方式中,软件模块使用包括包含计算机程序代码的计算机可读介质的计算机程序产品实现,其能够由计算机处理器执行用于执行任何或全部的所描述的步骤、操作或程序。
出于示例和描述的目的,已经给出了本发明实施的前述说明。前述说明并非是穷举性的也并非要将本发明限制到所公开的确切形式,根据上述教导还可能存在各种变形和修改,或者是可能从本发明的实践中得到各种变形和修改。选择和描述这些实施例是为了说明本发明的原理及其实际应用,以使得本领域的技术人员能够以适合于构思的特定用途来以各种实施方式和各种修改而利用本发明。
1.一种神经网络训练方法,其中,所述方法包括:
获取总训练集,所述总训练集包括多个类别的训练数据,其中每个所述类别包括一个或多个所述训练数据;
针对每一训练轮次,基于所述总训练集,对每个所述类别的训练数据进行采样,得到采样后的训练数据组成的子训练集;
根据神经网络对所述子训练集中的每个所述类别的所述训练数据进行特征提取,得到特征向量;
基于当前所述神经网络,确定每个所述类别的中心向量,其中所述中心向量作为锚点;
基于所述中心向量和所述特征向量获得损失函数的值,根据所述损失函数的值调整所述神经网络的参数。
2.根据权利要求1所述的神经网络训练方法,其中,所述基于当前所述神经网络,确定每个所述类别的中心向量,包括:
基于当前训练轮次的子训练集中每个所述类别的训练数据的所述特征向量,分别对所述特征向量对应类别的历史中心向量进行更新,得到每个所述类别的当前训练轮次的所述中心向量,其中所述历史中心向量为前一训练轮次的所述中心向量,对于第一训练轮次,前一训练轮次的所述中心向量为预设的每个所述类别的初始中心向量。
3.根据权利要求2所述的神经网络训练方法,其中,所述基于当前训练轮次的子训练集中每个所述类别的训练数据的所述特征向量,分别对所述特征向量对应类别的历史中心向量进行更新,采用以下方式:
分别设置所述子训练集中每个训练数据的所述特征向量的第一权重,以及所述历史中心向量的第二权重;
基于所述第一权重和所述第二权重,对所述当前训练轮次每个所述类别的训练数据的所述特征向量与对应类别的所述历史中心向量进行加权,获得对应类别的当前训练轮次的所述中心向量。
4.根据权利要求2所述的神经网络训练方法,其中,所述基于当前训练轮次的子训练集中每个所述类别的训练数据的所述特征向量,分别对所述特征向量对应类别的历史中心向量进行更新,采用以下方式:
基于每个所述类别的训练数据的所述特征向量,确定与每个所述类别对应的中心损失;
基于所述中心损失,更新对应所述类别的所述历史中心向量,得到对应所述类别的所述当前训练轮次所述中心向量。
5.根据权利要求4所述的神经网络训练方法,其中,所述基于所述中心损失,更新对应所述类别的所述历史中心向量,得到对应所述类别的所述当前训练轮次所述中心向量,包括:
基于所述中心损失,获取对应所述类别的所述历史中心向量与当前训练轮次的子训练集中对应所述类别的训练数据的所述特征向量之间的距离;
根据所述距离,确定对应所述类别的所述历史中心向量的更新量;
根据所述更新量更新对应所述类别的所述历史中心向量,得到对应所述类别的所述当前训练轮次所述中心向量。
6.根据权利要求1所述的神经网络训练方法,其中,所述基于当前所述神经网络,确定每个所述类别的中心向量,包括:
基于所述神经网络的每个所述类别的分类输出权重,确定对应类别的所述中心向量。
7.根据权利要求1所述的神经网络训练方法,其中,所述基于所述中心向量和所述特征向量获得损失函数的值,根据所述损失函数的值调整所述神经网络的参数,包括:
基于所述中心向量,在当前训练轮次的子训练集全部训练数据的所述特征向量中确定第一向量和第二向量,其中,所述第一向量为与所述中心向量类别相同的特征向量,所述第二向量为与所述中心向量类别不同的特征向量;
基于所述中心向量、所述第一向量以及所述第二向量,通过三元组损失函数得到损失值,并基于所述损失值调整所述神经网络的参数。
8.根据权利要求7所述的神经网络训练方法,其中,所述第一向量为与所述中心向量之间距离最远且与所述中心向量类别相同的特征向量;所述第二向量为与所述中心向量之间距离最近且与所述中心向量类别不同的特征向量。
9.一种图像处理方法,其中,所述方法包括:
获取图像;
通过神经网络进行图像识别,得到所述图像的分类结果,其中所述神经网络通过权利要求1-8任一项所述的神经网络训练方法训练得到。
10.一种神经网络训练装置,其中,所述装置包括:
获取模块,用于获取总训练集,所述总训练集包括多个类别的训练数据,其中每个所述类别包括一个或多个所述训练数据;
采样模块,用于针对每一训练轮次,基于所述总训练集,对每个所述类别的训练数据进行采样,得到采样后的训练数据组成的子训练集;
特征提取模块,用于根据神经网络对所述子训练集中的每个所述类别的所述训练数据进行特征提取,得到特征向量;
中心确定模块,用于确定每个所述类别的当前轮次的中心向量,其中所述中心向量作为锚点;
训练模块,用于基于所述中心向量和所述特征向量获得损失函数的值,根据所述损失函数的值调整所述神经网络的参数。
11.一种图像处理装置,其中,所述装置包括:
接收模块,用于获取图像;
处理模块,用于通过神经网络进行图像识别,得到所述图像的分类结果,其中所述神经网络通过权利要求1-8任一项所述的神经网络训练方法训练得到。
12.一种电子设备,其中,所述电子设备包括:
存储器,用于存储指令;以及
处理器,用于调用所述存储器存储的指令执行权利要求1-8任一项所述的神经网络训练方法或权利要求9所述的图像处理方法。
13.一种计算机可读存储介质,其中,所述计算机可读存储介质存储有计算机可执行指令,所述计算机可执行指令在由处理器执行时,执行权利要求1-8任一项所述的神经网络训练方法或权利要求9所述的图像处理方法。
技术总结