本发明涉及人工智能领域和计算机视觉领域,是一种基于强化学习和注意力机制的图像分类方法。
背景技术:
图像分类是计算机视觉研究中最基本的问题,是后续高层次视觉任务的基础。
近年来,受到人的视觉生物系统的启发,注意力机制越来越受到学术界和工业界研究的重视,并且取得了很多非常好的研究成果,在业界得到了广泛的应用。注意力机制attention从本质上讲和人类的选择性机制类似,核心目标也是从众多信息中选择出对当前任务更关键的信息。目前attention的形式主要分为软注意力softattention和硬注意力hardattention两种形式。softattention是参数化的,因此可导,能够被嵌入到模型中去,梯度可以直接反向传播到模型其他部分,这种方式简单直接,所以在目前的研究和应用中占据着主流地位。hardattention不会选择全部数据做为其输入,而是会依概率来采样一部分数据来进行计算,为了实现梯度的反向传播,需要采用蒙特卡洛采样的方法来估计梯度,然后利用强化学习算法进行训练,这种方法较为复杂,在业界的研究不是很多。
deepmind2014年针对经典的mnist手写体数字识别提出了recurrentmodelsofvisualattention(ram)模型,ram将attentionproblem看做是目标引导的序列决策过程,能够和视觉环境交互。在每一个时间点,智能体agent只能根据有带宽限制的感知器来观察全局,即只能在一个局部区域或者狭窄的频域范围进行信息的提取。theagent可以自主的控制如何布置感知器的资源,即:选择感知的位置区域。该agent也可以通过执行actions来影响环境的真实状态。由于该环境只是部分可观察,所以他需要额外的信息来辅助其进行决定如何行动和如何最有效的布置感知器。每一步,agent都会收到奖励或者惩罚,agent的目标就是将奖励最大化。该模型可以看作是首次将强化学习应用到图像分类中,模型在设计上存在一些不足,比如需要固定的输入步长、缺少全局信息等,因此在应到其他复杂分类问题时效果不好。
wangf等人在arxiv发表的论文《residualattentionnetworkforimageclassification》,提出了一种residualattentionnetwork,是attentionmodule的堆叠。在每个module中均使用bottom-uptop-down结构,利用残差机制使得网络深度可以进一步扩展。这种注意力机制是可微的,可以通过反向传播训练。
kelvinxu等人在2015年发表的论文《show,attendandtell:neuralimagecaptiongenerationwithvisualattention》中,给出了hardattention和softattention的定义。在hardattention机制中,权重
将
相比之下,在softattention机制中,权重
该定义是参数化的,因此可导,可以直接用于训练,比较简单,因此成为研究的主流。
huj等人在arxiv发表了论文《squeeze-and-excitationnetworks》,senet的核心思想在于通过网络根据loss去学习特征权重,使得有效的特征图(featuremap)权重大,无效或效果小的featuremap权重小的方式训练模型达到更好的结果。
综上所述,attention机制在计算机视觉研究中有了广泛的研究,它的本质是让模型能够像人类的视觉系统一样,更关注图像中重要的区域,而对那些不重要的信息尽可能忽略。
目前业界的研究应用主要以softattention为主,hardattention的研究相对要少很多。这两种attention机制各有各的优点,然而在图片分类任务实际需求中,还有明显的不足:
1.目前已有的hardattention框架或者需要标准的序列数据输入,或者采用固定长度的图像局部序列数据且这些数据缺少全局信息的控制,导致应用在复杂的图像分类任务时效果不好;
2.图像中有许多数据对于分类来说是冗余的,甚至可能只有很少一部分数据是有价值的,而softattention机制需要所有的数据参与计算,这样不仅仅增加了不必要的计算量,而且还有可能引入噪声,影响分类结果;
3.softattention机制在解释性上不够好,在图像分类任务中,有时候我们需要关注模型学习到了什么,或者是学习到图像中哪个区域是重要的,softattention由于会对所有的区域生成一个权重分布,不像hardattention那样明确选择特定区域,因此在解释性上不如hardattention。
技术实现要素:
针对上述现有技术中存在的不足,本发明的目的是提供一种基于强化学习的图像分类方法。它从注意力和强化学习入手,自动对输入的特征数据进行选择,然后利用这些选择的特征进行最终的分类,该方法不仅排除了大量的冗余噪声数据,提高了分类效果,而且有很好的可解释性。
为了达到上述发明目的,本发明的技术方案以如下方式实现:
一种基于强化学习的图像分类方法,它使用依次连接的卷积神经网络(cnn)以及特征选择和分类模块(fscm)。所述特征选择和分类模块包括工作模块(worker)和管理模块(manager)两个子模块,工作模块由长短时记忆模型(lstm)和与之相连接的动作模型组成,管理模块使用长短时记忆模型。其方法步骤为:
1)将输入图片经过卷积神经网络生成特征图,并按行序转为通道序列数据。
2)特征选择和分类模块中的工作模块利用长短时记忆模型对输入的当前数据(currentdata)依次进行处理,生成所有输入数据的权重概率分布。首先,初始化工作模块中长短时记忆模型的
3)将当前数据
a)训练阶段:工作模块采用近端策略优化算法(ppo)联合自适应时刻估计算法(adam)的方式进行训练,管理模块采用自适应时刻估计算法进行训练。计算时引入全局信息,并将选择的序列数据最大长度设置为8,在此基础上引入一个终止动作(terminal),模型的具体定义为下式:
其中,
动作模型采用硬注意力机制(hardattention)从权重概率分布中选择数据进行采样,其训练时采用如下方式采样:
获得当前应采取的动作(action)。
b)预测阶段:预测过程采用前向计算方式,工作模块从卷积神经网络输入的特征数据中进行自动选择。首先由长短时记忆模型按照公式(1)计算,然后动作模型采用硬注意力机制选取概率最大的权重值对应的数据,即:
获得当前应采取的动作,若动作不为终止动作,则跟据动作获得对应的通道数据,作为下一时刻待处理数据(nextdata)
4)将下一时刻待处理数据
5)将收集到的当前数据序列(currentdatasequence)送入到管理模块进行分类。
a)训练阶段:工作模块训练需要的奖励函数由管理模块提供,奖励函数定义如下:
其中,
根据分类结果和公式(4)分别训练工作模块和管理模块。
b)预测阶段:将分类结果输出。
本发明由于采用了上述方法步骤,同现有技术相比具有如下优点:
1.模型自动选择图像的某些特征数据进行分类,很大程度上减少了冗余数据,提高了模型的性能和计算效率。
2.通过特征选择和分类模块,可以很直观看到图片哪些区域参与了分类任务,可以了解到模型学到了什么,相比于softattention有更好的解释性。
3.工作模块处理输入数据增加了全局信息,因此可以生成更合理的采样分布,另外,引入终止状态使得模型更加灵活、高效。
下面结合附图和具体实施方式对本发明做进一步说明。
附图说明
图1为本发明方法整体流程图;
图2为本发明实施例中卷积神经网络结构示意图;
图3为本发明实施例中特征选择和分类模块流程图。
具体实施方式
参看图1至图3,本发明基于强化学习的图像分类方法,它使用依次连接的卷积神经网络以及特征选择和分类模块。特征选择和分类模块包括工作模块和管理模块两个子模块,工作模块由长短时记忆模型和与之相连接的动作模型组成,管理模块使用长短时记忆模型。其方法步骤为:
1)将输入图片经过卷积神经网络生成特征图,并按行序转为通道序列数据。
卷积神经网络包括13个卷积(conv)层,4个池化(pool)层。其中,输入图像大小为num*3*224*224,所有的conv层卷积核的大小都为3*3,步长(stride),每层的通道数目如下:
conv1:16
conv2:64
conv3:128
conv4:64
conv5:128
conv6:256
conv7:128
conv8:256
conv9:512
conv10:512
conv11:256
conv12:512
conv13:512
池化层的核的大小为2*2,步长为2。卷积神经网络最后输出的特征图大小为14*14*512。
2)特征选择和分类模块中的工作模块利用长短时记忆模型对输入的当前数据依次进行处理生成所有输入数据的权重概率分布。首先,初始化工作模块中长短时记忆模型的
3)将当前数据
a)训练阶段:工作模块采用近端策略优化算法联合自适应时刻估计算法的方式进行训练,管理模块采用自适应时刻估计算法进行训练。计算时引入全局信息,并将选择的序列数据最大长度设置为8,在此基础上引入一个终止动作,模型的具体定义为下式:
其中,
动作模型采用硬注意力机制从权重概率分布中选择数据进行采样,其训练时采用如下方式采样:
获得当前应采取的动作。
b)预测阶段:预测过程采用前向计算方式,工作模块从卷积神经网络输入的特征数据中进行自动选择。首先由长短时记忆模型按照公式(1)计算,然后动作模型采用硬注意力机制选取概率最大的权重值对应的数据,即:
获得当前应采取的动作,若动作不为终止动作,则跟据动作获得对应的通道数据,作为下一时刻待处理数据
4)将下一时刻待处理数据
5)将收集到的当前数据序列送入到管理模块进行分类。
a)训练阶段:工作模块训练需要的奖励函数由管理模块提供,奖励函数定义如下:
其中,
根据分类结果和公式(4)分别训练工作模块和管理模块。
b)预测阶段:将分类结果输出。
本发明实施例仅为说明本申请技术方案,本领域技术人员在本申请基础上所做的同类替代,如将卷积神经网络替换为结合其他深度学习模型或者机器学习的方案,将处理序列数据的长短时记忆模型替换为其他方法,或是将硬注意力机制采样方法替换为其他方法等均应属于本申请保护的范围。
1.一种基于强化学习的图像分类方法,它使用依次连接的卷积神经网络以及特征选择和分类模块,所述特征选择和分类模块包括工作模块和管理模块两个子模块,工作模块由长短时记忆模型和与之相连接的动作模型组成,管理模块使用长短时记忆模型;其方法步骤为:
1)将输入图片经过卷积神经网络生成特征图,并按行序转为通道序列数据;
2)特征选择和分类模块中的工作模块利用长短时记忆模型对输入的当前数据依次进行处理,生成所有输入数据的权重概率分布,首先初始化工作模块中长短时记忆模型的
3)将当前数据
a)训练阶段:工作模块采用近端策略优化算法联合自适应时刻估计算法的方式进行训练,管理模块采用自适应时刻估计算法进行训练;计算时引入全局信息,并将选择的序列数据最大长度设置为8,在此基础上引入一个终止动作,模型的具体定义为下式:
其中,
动作模型采用硬注意力机制从权重概率分布中选择数据进行采样,其训练时采用如下方式采样:
获得当前应采取的动作;
b)预测阶段:预测过程采用前向计算方式,工作模块从卷积神经网络输入的特征数据中进行自动选择,首先由长短时记忆模型按照公式(1)计算,然后动作模型采用硬注意力机制选取概率最大的权重值对应的数据,即:
获得当前应采取的动作,若动作不为终止动作,则跟据动作获得对应的通道数据,作为下一时刻待处理数据
4)将下一时刻待处理数据
5)将收集到的当前数据序列送入到管理模块进行分类;
a)训练阶段:工作模块训练需要的奖励函数由管理模块提供,奖励函数定义如下:
其中,
根据分类结果和公式(4)分别训练工作模块和管理模块;
b)预测阶段:将分类结果输出。
技术总结