技术领域
本发明涉及计算机视觉领域,更具体地,涉及一种基于集成学习的零样本哈希检索方法。
背景技术
随着互联网的快速发展,各种各样的数据呈爆炸式增长,包括图片、文字、视频等各类信息。这就导致了人们将会耗费大量时间来寻找自己感兴趣的内容,当人们带着目标来浏览网页或者使用手机软件时,面对如此巨大的数据库,成千上万条信息陈列在界面上,往往很难通过肉眼来迅速找到所有目标信息,因此检索系统应运而生。
图片检索是检索系统的一个重要组成部分。哈希技术一直被用于快速图片检索领域。将图片标签作为监督信息,来训练深度神经网络,能使哈希技术的效果变得更好。然而,网络上每天都在产生新的概念、新的图片,这为检索系统带来了新的挑战:零样本问题,也就是当训练好的模型遇到从未见过的新类别图片时,检索效果会变得非常差。为了解决这个问题,在本发明中提出了一个零样本的哈希检索方法,利用集成学习的思想,提升模型的泛化能力,从而使模型在遇到新类别的图片时,也能有良好的检索效果。
集成学习主要是组合多个弱监督模型以得到一个更好更全面的强监督模型,它是将几种机器学习技术组合成一个预测模型的元算法,以达到减小方差、偏差或改进预测的效果。集成学习方法主要分为两大类:关于数据集的集成、关于模型融合的集成。其中,关于数据集的集成是指,利用Bootstrap方法进行抽样得到多个数据集,或者通过更新各个样本权重的方式来改变数据分布,分别训练多个模型进行组合,比如Bagging和Boosting方法;关于模型融合的集成是指,采取不同的方式组合多个学习器,从而获得更好的效果,对于回归问题,可以使用平均法;对于分类问题,可以使用多数投票法。
申请号为201510200864.2的专利说明书中公开了一种基于集成哈希编码的快速图像检索方法,本申请首先提取训练图像和查询图像的SIFT特征,并利用M种哈希算法对训练图像进行初始哈希编码;然后利用集成学习中的一致性约束准则对初始哈希编码结果进行再学习,得到集成哈希映射矩阵;最后重新对训练图像和查询图像进行集成哈希编码,并在集成哈希编码的基础上通过计算查询图像与训练图像之间的汉明距离来进行快速检索。本发明中的集成哈希编码能够同时融合不同哈希算法的特点和优势,解决了单一哈希算法判别力不足和适用范围的局限性问题,从而使得图像的快速检索更加准确和高效。然而,该专利无法实现结合了划分数据集和组合模型两种策略来提升模型的泛化能力,检索效果也大大提升。
发明内容
本发明提供一种基于集成学习的零样本哈希检索方法,该方法结合了划分数据集和组合模型两种策略来提升模型的泛化能力,检索效果也大大提升。
为了达到上述技术效果,本发明的技术方案如下:
一种基于集成学习的零样本哈希检索方法,包括以下步骤:
S1:将训练集按照类别标签分为类别不重叠的两部分A和B;
S2:分别用A、B和A+B作为训练数据,通过VGG-16模型和一层全连接层,得到训练样本的哈希码;
S3:利用三元组损失,得到训练过程中的损失;
S4:利用SWA方法训练更新网络,得到收敛后的模型;
S5:步骤S2中的3个数据集训练可以得到3个不同的模型,求它们求平均值,得到最终的集成模型;
S6:计算集成模型在测试集上的检索结果。
进一步地,所述步骤S1的具体过程是:
将训练集按照类别标签分为类别不重叠的两部分A和B。对于数据集cifar10来说,会把1~9类划分为训练集,第10类划分为测试集。在训练的过程中,需要把训练集按照类别划分为A(第1~5类)和B(第6~9类)两部分。
进一步地,所述步骤S2的哈希模型的设计是:
S21:首先,分别将数据集A、B和A+B分别作为训练集,训练出3个不同的模型。具体的训练步骤如下;
S22:使用VGG-16模型提取训练集中的图像样本的高维实数特征(4096维);
S23:将S22步骤得到的高维实数特征输入到全连接层和tanh激活函数后,得到实数向量v,再对v进行二值化(大于0的元素设置为1,小于0的元素设置为0),得到二进制码b,即哈希码。量化公式如下:
进一步地,所述步骤S3的三元组损失具体步骤是:
S31:在每个batch的训练样本中,构造三元组<I,I
S32:三元组损失计算公式如下所示:
其中超参数margin,表示I和负样本I
进一步地,所述步骤S4中SWA的训练过程是:
S41:先用VGG-16的预训练模型参数初始化特征提取模型,然后随机初始化最后一层全连接层(用于获得哈希码的全连接层),得到初始化后的权重
S42:迭代n轮,训练模型;
S43:对于第i轮迭代,依次更新学习率和模型权重,更新公式如下:
循环学习率:
更新网络权重:
S44:对于第i轮迭代,若mod(i,c)=0,其中,c是一个预设的超参数,表示循环长度,用滑动平均的方式更新最终的网络权重w
n
进一步地,步骤S5中得到集成模型的过程为:
S51:用训练集A、B和A+B分别训练模型,得到三个不同的模型权重w
进一步地,步骤S6中,计算集成模型在测试集上的检索准确性(mAP)的过程如下:
S61:计算查询图像哈希码与数据库中所有图像哈希码之间的汉明距离(汉明距离=将对应位上的哈希码字做异或操作并求和,即不同取值的码元数目);
S62:将数据库中的哈希码按照与查询图像的汉明距离从小到大排序,根据标签依次判断该图像是否与文本属于同类,同类即检索正确,以此计算AP值。
整体的检索指标:
平均准确率:
其中i代表第i张测试集图片;I是测试集图片的数目。k表示在用第i张图片作为查询图片时得到的检索列表中的排序位置;P
与现有技术相比,本发明技术方案的有益效果是:
本发明将集成学习的方法应用于零样本图片检索问题。使用VGG-16提取图片的高维实数特征,然后用全连接层和激活函数将高维实数特征转化为低维二进制哈希码,在保证检索效果的前提下,减少了存储空间。之后,利用集成学习的训练方法,更新哈希模型,从而使模型具有更强的泛化能力,使模型在新类别的图片上的检索效果也大大提升。
附图说明
图1为本发明的算法流程图;
图2为本发明的SWA方法示意图。
具体实施方式
附图仅用于示例性说明,不能理解为对本专利的限制;
为了更好说明本实施例,附图某些部件会有省略、放大或缩小,并不代表实际产品的尺寸;
对于本领域技术人员来说,附图中某些公知结构及其说明可能省略是可以理解的。
下面结合附图和实施例对本发明的技术方案做进一步的说明。
如图1所示,一种基于集成学习的零样本哈希检索方法,包括以下步骤:
S1:将训练集按照类别标签分为类别不重叠的两部分A和B;
S2:分别用A、B和A+B作为训练数据,通过VGG-16模型和一层全连接层,得到训练样本的哈希码;
S3:利用三元组损失,得到训练过程中的损失;
S4:利用SWA方法训练更新网络,得到收敛后的模型;
S5:步骤S2中的3个数据集训练可以得到3个不同的模型,求它们求平均值,得到最终的集成模型;
S6:计算集成模型在测试集上的检索结果。
步骤S1的具体过程是:
将训练集按照类别标签分为类别不重叠的两部分A和B。对于数据集cifar10来说,会把1~9类划分为训练集,第10类划分为测试集。在训练的过程中,需要把训练集按照类别划分为A(第1~5类)和B(第6~9类)两部分。
步骤S2的哈希模型的设计是:
S21:首先,分别将数据集A、B和A+B分别作为训练集,训练出3个不同的模型。具体的训练步骤如下;
S22:使用VGG-16模型提取训练集中的图像样本的高维实数特征(4096维);
S23:将S22步骤得到的高维实数特征输入到全连接层和tanh激活函数后,得到实数向量v,再对v进行二值化(大于0的元素设置为1,小于0的元素设置为0),得到二进制码b,即哈希码。量化公式如下:
步骤S3的三元组损失具体步骤是:
s31:在每个batch的训练样本中,构造三元组<I,I
S32:三元组损失计算公式如下所示:
其中超参数margin,表示I和负样本I
如图2所示,步骤S4中SWA的训练过程是:
S41:先用VGG-16的预训练模型参数初始化特征提取模型,然后随机初始化最后一层全连接层(用于获得哈希码的全连接层),得到初始化后的权重
S42:迭代n轮,训练模型;
S43:对于第i轮迭代,依次更新学习率和模型权重,更新公式如下:
循环学习率:
更新网络权重:
S44:对于第i轮迭代,若mod(i,c)=0,其中,c是一个预设的超参数,表示循环长度,用滑动平均的方式更新最终的网络权重w
n
步骤S5中得到集成模型的过程为:
S51:用训练集A、B和A+B分别训练模型,得到三个不同的模型权重w
进一步地,步骤S6中,计算集成模型在测试集上的检索准确性(mAP)的过程如下:
S61:计算查询图像哈希码与数据库中所有图像哈希码之间的汉明距离(汉明距离=将对应位上的哈希码字做异或操作并求和,即不同取值的码元数目);
S62:将数据库中的哈希码按照与查询图像的汉明距离从小到大排序,根据标签依次判断该图像是否与文本属于同类,同类即检索正确,以此计算AP值。
整体的检索指标:
平均准确率:
其中i代表第i张测试集图片;I是测试集图片的数目。k表示在用第i张图片作为查询图片时得到的检索列表中的排序位置;P
相同或相似的标号对应相同或相似的部件;
附图中描述位置关系的用于仅用于示例性说明,不能理解为对本专利的限制;
显然,本发明的上述实施例仅仅是为清楚地说明本发明所作的举例,而并非是对本发明的实施方式的限定。对于所属领域的普通技术人员来说,在上述说明的基础上还可以做出其它不同形式的变化或变动。这里无需也无法对所有的实施方式予以穷举。凡在本发明的精神和原则之内所作的任何修改、等同替换和改进等,均应包含在本发明权利要求的保护范围之内。
机译: 一种用于基于多个输入样本提供多个输出样本的信号处理装置,以及用于基于多个输入样本提供多个输出样本的方法
机译: 一种用于基于一组输入样本提供多个输出样本的信号处理装置,以及用于基于一组输入样本提供多个输出样本的方法
机译: 基于变长深度哈希学习的图像检索方法