本发明涉及深度学习技术领域,尤其涉及一种基于高斯混合模型先验变分自编码器的图像生成方法。
背景技术:
在互联网时代,机器学习发展迅速,取得了很大的成就,其中图像生成技术作为机器学习的一个分支,对我们理解图像发挥了重要的作用。图像生成模型是用于对图像进行概率建模的概率模型,而深度神经网络可以看成是一个非常复杂的、拟合能力非常强的非线性函数,可以用于搭建生成模型来估计概率密度函数的参数。图像生成模型可以用于更多不同图片样本的生成,可以用于图像信息的恢复,也可以用于不同模态的图片或者图片与文字、语音等之间的转换,还可以用于预测未来,例如可以根据视频中过去帧和当前帧预测未来帧。
变分自编码器是一种十分著名的基于深度学习的图像生成模型,它是变分推断的自然发展,它结合elbo和神经网络的优势,解决了通用场景下的推断问题,同时也解决了连续数据的生成问题。它具有很多优势,包括训练快、稳定等,从而在理论模型和工业界上都有广泛的应用。然而,标准的变分自编码器先验由于存在欠拟合问题,往往会生成比较模糊的图片。
技术实现要素:
本发明的目的是针对现有技术的缺陷,提供了一种基于高斯混合模型先验变分自编码器的图像生成方法,可以对复杂图像进行建模,生成高质量图片,这大大提高了模型的生成能力。
为了实现以上目的,本发明采用以下技术方案:
一种基于高斯混合模型先验变分自编码器的图像生成方法,包括步骤:
s1.预设生成图像训练数据集;其中,所述训练数据集由若干批次的训练数据组成;
s2.搭建基于高斯混合模型先验的变分自编码器网络;
s3.将所述预设的若干批次的训练数据上传至搭建的变分自编码器网络中,并确定所述变分自编码器网络的后验分布和先验分布;
s4.确定所述高斯混合模型中高斯分量之间的关系,得到映射函数;
s5.利用所述变分自编码器网络和得到的映射函数得到重构损失函数和kl散度函数,根据所述得到的重构损失函数和kl散度函数计算所述变分自编码器网络的后验分布和先验分布的损失函数,并对所述变分自编码器网络的参数进行更新以生成图像;
s6.当生成图像时,将伪输入作为输入图像上传至所述变分自编码器网络,得到最终生成的图片。
进一步的,所述步骤s2还包括构建变分自编码器网络中隐变量的后验分布。
进一步的,所述步骤s2中搭建的变分自编码器网络中的参数包括网络输入图像尺寸c×h×w、批次大小b,隐变量维数d,隐变量z,高斯混合数量m,伪输入α,伪输入数量k。
进一步的,所述步骤s3中将所述预设的若干批次的训练数据上传至搭建的变分自编码器网络中,其中上传的训练数据包括图像样本x={x1,x2,…,xn},其中,xi为当前批次中第i个样本,i=1,2,…b、伪输入α={α1,α2,...,αk},其中αj表示第j个伪输入,j=1,2,…k。
进一步的,所述步骤s3确定的是隐变量的后验分布以及聚合后验形式的隐变量先验分布;
所述隐变量后验分布为:
所述隐变量先验为:
其中,
进一步的,所述步骤s4中确定所述高斯混合模型中高斯分量之间的关系是通过贪心算法确定的。
进一步的,所述步骤s4具体为根据以下函数依次构造映射函数:
其中,a={β(t)|t=1,...,m-1};β(·)表示映射函数。
进一步的,所述步骤s5中得到的重构损失函数为:
其中,n表示输入图片的维度;xi表示输入样本图片第i维度的值;
进一步的,所述步骤s5中得到的kl散度函数为:
其中,lkl表示每个样本的kl距离。
进一步的,所述计算得到的损失函数为:
其中,
与现有技术相比,本发明搭建基于优化高斯混合模型先验的变分自编码器网络,训练效率高,收敛性强,该网络可以对复杂图像进行建模,生成高质量图片,这大大提高了模型的生成能力。
附图说明
图1是实施例一提供的一种基于高斯混合模型先验变分自编码器的图像生成方法流程图;
图2为实施例一提供的基于优化高斯混合模型先验的变分自编码器网络示意图。
具体实施方式
以下通过特定的具体实例说明本发明的实施方式,本领域技术人员可由本说明书所揭露的内容轻易地了解本发明的其他优点与功效。本发明还可以通过另外不同的具体实施方式加以实施或应用,本说明书中的各项细节也可以基于不同观点与应用,在没有背离本发明的精神下进行各种修饰或改变。需说明的是,在不冲突的情况下,以下实施例及实施例中的特征可以相互组合。
本发明的目的是针对现有技术的缺陷,提供了一种基于高斯混合模型先验变分自编码器的图像生成方法。
实施例一
本实施例提供一种基于高斯混合模型先验变分自编码器的图像生成方法,如图1-2所示,包括步骤:
s11.预设生成图像训练数据集;其中,所述训练数据集由若干批次的训练数据组成;
s12.搭建基于高斯混合模型先验的变分自编码器网络;
s13.将所述预设的若干批次的训练数据上传至搭建的变分自编码器网络中,并确定所述变分自编码器网络的后验分布和先验分布;
s14.确定所述高斯混合模型中高斯分量之间的关系,得到映射函数;
s15.利用所述变分自编码器网络和得到的映射函数得到重构损失函数和kl散度函数,根据所述得到的重构损失函数和kl散度函数计算所述变分自编码器网络的后验分布和先验分布的损失函数,并对所述变分自编码器网络的参数进行更新以生成图像;
s16.当生成图像时,将伪输入作为输入图像上传至所述变分自编码器网络,得到最终生成的图片。
在步骤s11中,预设生成图像训练数据集;其中,所述训练数据集由若干批次的训练数据组成。
准备符合要求的生成图像训练数据集,定义每一批次大小为b的训练数据,由b张图像样本组成训练批次{x1,x2,…,xn}。
在步骤s12中,搭建基于高斯混合模型先验的变分自编码器网络。
搭建基于高斯混合模型先验的变分自编码器网络中隐变量的后验分布,如图2所示,该变分自编码器网络由一个自底向上的推断模型和一个自顶向下的生成模型组成,且具有正向、反向传播的功能。其中搭建的变分自编码器网络中的参数包括网络输入图像尺寸c×h×w、批次大小b,隐变量维数d,隐变量z,高斯混合数量m,伪输入α,伪输入数量k。在本实施例中,取d=40,m=3,k=500。
需要说明的是,本实施例在变分自编码器的基础上,建立高斯混合模型。与变分自编码器不同,该隐变量的后验分布是通过高斯混合模型而不是单个高斯模型构建的,为了简化计算每个分量的协方差矩阵是对角的。
在步骤s13中,将所述预设的若干批次的训练数据上传至搭建的变分自编码器网络中,并确定所述变分自编码器网络的后验分布和先验分布。
将预设的若干批次的训练数据上传至搭建的变分自编码器网络中,而对于每一批次送入上述变分自编码器网络的训练批次,其包含图像样本,表述为x={x1,x2,…,xn},其中,xi为该批次中第i个样本,i=1,2,…b,以及可学习的伪输入{α1,α2,...,αk},其中αj表示第j个伪输入,j=1,2,…k;确定上述网络的隐变量的后验分布q(z|x)和聚合后验形式的隐变量先验分布p(z)。
隐变量后验分布为:
隐变量先验为:
其中,
在步骤s14中,确定所述高斯混合模型中高斯分量之间的关系,得到映射函数。
使用贪心算法确定两个高斯模型中的高斯分量之间的对应关系,得出映射函数β(·)。
在本实施例中,根据高斯权值πm从大到小排序参照高斯混合模型中的所有高斯分量,使得c1≥c2≥…≥cm,令i=1,
根据公式
其中,a=a∪{β(m)},如果m<m,m=m 1,则返回依次构造映射函数的步骤中;否则结束,继续执行步骤s15。
在步骤s15中,利用所述变分自编码器网络和得到的映射函数得到重构损失函数和kl散度函数,根据所述得到的重构损失函数和kl散度函数计算所述变分自编码器网络的后验分布和先验分布的损失函数,并对所述变分自编码器网络的参数进行更新以生成图像。
在本实施例中,对上述变分自编码器网络通过重构损失函数和kl散度函数,并分别对上述变分自编码器网络的输入x和输出
计算每个样本的重构损失函数为:
其中,n表示输入图片的维度;xi表示输入样本图片第i维度的值;
使用映射函数β(.),计算每个样本的kl散度函数为:
其中,lkl表示每个样本的kl距离。
通过计算出的每个样本的重构损失函数和每个样本的kl散度函数进而计算变分自编码器网络的输入x和输出
其中,
在步骤s16中,当生成图像时,将伪输入作为输入图像上传至所述变分自编码器网络,得到最终生成的图片。
生成图像时,将伪输入作为输入图像送入上述网络,即可输出高质量的生成图片。
在本实施例中,所涉及名词的解释如下:
高斯混合模型就是用高斯概率密度函数(正态分布曲线)精确地量化事物,它是一个将事物分解为若干的基于高斯概率密度函数(正态分布曲线)形成的模型。
贪心算法(又称贪婪算法)是指,在对问题求解时,总是做出在当前看来是最好的选择。也就是说,不从整体最优上加以考虑,他所做出的是在某种意义上的局部最优解。贪婪算法是一种改进了的分级处理方法。其核心是根据题意选取一种量度标准。然后将这多个输入排成这种量度标准所要求的顺序,按这种顺序一次输入一个量。如果这个输入和当前已构成在这种量度意义下的部分最佳解加在一起不能产生一个可行解,则不把此输入加到这部分解中。这种能够得到某种量度意义下最优解的分级处理方法称为贪婪算法。
与现有技术相比,本实施例搭建基于优化高斯混合模型先验的变分自编码器网络,训练效率高,收敛性强,该网络可以对复杂图像进行建模,生成高质量图片,这大大提高了模型的生成能力。
注意,上述仅为本发明的较佳实施例及所运用技术原理。本领域技术人员会理解,本发明不限于这里所述的特定实施例,对本领域技术人员来说能够进行各种明显的变化、重新调整和替代而不会脱离本发明的保护范围。因此,虽然通过以上实施例对本发明进行了较为详细的说明,但是本发明不仅仅限于以上实施例,在不脱离本发明构思的情况下,还可以包括更多其他等效实施例,而本发明的范围由所附的权利要求范围决定。
1.一种基于高斯混合模型先验变分自编码器的图像生成方法,其特征在于,包括步骤:
s1.预设生成图像训练数据集;其中,所述训练数据集由若干批次的训练数据组成;
s2.搭建基于高斯混合模型先验的变分自编码器网络;
s3.将所述预设的若干批次的训练数据上传至搭建的变分自编码器网络中,并确定所述变分自编码器网络的后验分布和先验分布;
s4.确定所述高斯混合模型中高斯分量之间的关系,得到映射函数;
s5.利用所述变分自编码器网络和得到的映射函数得到重构损失函数和kl散度函数,根据所述得到的重构损失函数和kl散度函数计算所述变分自编码器网络的后验分布和先验分布的损失函数,并对所述变分自编码器网络的参数进行更新以生成图像;
s6.当生成图像时,将伪输入作为输入图像上传至所述变分自编码器网络,得到最终生成的图片。
2.根据权利要求1所述的一种基于高斯混合模型先验变分自编码器的图像生成方法,其特征在于,所述步骤s2还包括构建变分自编码器网络中隐变量的后验分布。
3.根据权利要求2所述的一种基于高斯混合模型先验变分自编码器的图像生成方法,其特征在于,所述步骤s2中搭建的变分自编码器网络中的参数包括网络输入图像尺寸c×h×w、批次大小b,隐变量维数d,隐变量z,高斯混合数量m,伪输入α,伪输入数量k。
4.根据权利要求3所述的一种基于高斯混合模型先验变分自编码器的图像生成方法,其特征在于,所述步骤s3中将所述预设的若干批次的训练数据上传至搭建的变分自编码器网络中,其中上传的训练数据包括图像样本x={x1,x2,…,xn},其中,xi为当前批次中第i个样本,i=1,2,…b、伪输入α={α1,α2,...,αk},其中αj表示第j个伪输入,j=1,2,…k。
5.根据权利要求4所述的一种基于高斯混合模型先验变分自编码器的图像生成方法,其特征在于,所述步骤s3确定的是隐变量的后验分布以及聚合后验形式的隐变量先验分布;
所述隐变量后验分布为:
所述隐变量先验为:
其中,
6.根据权利要求5所述的一种基于高斯混合模型先验变分自编码器的图像生成方法,其特征在于,所述步骤s4中确定所述高斯混合模型中高斯分量之间的关系是通过贪心算法确定的。
7.根据权利要求6所述的一种基于高斯混合模型先验变分自编码器的图像生成方法,其特征在于,所述步骤s4具体为根据以下函数依次构造映射函数:
其中,a={β(t)|t=1,...,m-1};β(·)表示映射函数。
8.根据权利要求7所述的一种基于高斯混合模型先验变分自编码器的图像生成方法,其特征在于,所述步骤s5中得到的重构损失函数为:
其中,n表示输入图片的维度;xi表示输入样本图片第i维度的值;
9.根据权利要求8所述的一种基于高斯混合模型先验变分自编码器的图像生成方法,其特征在于,所述步骤s5中得到的kl散度函数为:
其中,lkl表示每个样本的kl距离。
10.根据权利要求9所述的一种基于高斯混合模型先验变分自编码器的图像生成方法,其特征在于,所述计算得到的损失函数为:
其中,