首页> 中国专利> 一种基于域自适应的多初始值元学习框架及方法

一种基于域自适应的多初始值元学习框架及方法

摘要

本发明提供一种基于域自适应的多初始值元学习框架及方法,框架包括跨域编码器,将输入数据通过共有编码器编码为共有特征向量,通过私有编码器编码为私有特征向量;跨域调制网络,将共有特征向量编码为域公用调制向量,私有特征向量编码为域专用调制向量;元分离网络,用于在源域和目标域中更新元学习器,其中元学习器的参数分为由域公用调制向量调制的参数公共部分和由域专用调制向量调制的参数私有部分,该学习框架及方法可在一定程度上提高算法在少样本问题中的准确率,并广泛适用于跨域数据的元学习中。

著录项

  • 公开/公告号CN112734049A

    专利类型发明专利

  • 公开/公告日2021-04-30

    原文格式PDF

  • 申请/专利权人 西湖大学;

    申请/专利号CN202110210507.X

  • 发明设计人 陈政聿;王东林;

    申请日2021-02-25

  • 分类号G06N20/00(20190101);G06N3/08(20060101);

  • 代理机构32260 无锡市汇诚永信专利代理事务所(普通合伙);

  • 代理人李珍珍

  • 地址 310000 浙江省杭州市西湖区转塘街道石龙山街18号

  • 入库时间 2023-06-19 10:48:02

说明书

技术领域

本发明涉及元学习技术领域,具体涉及一种基于域自适应的多初始值元学习框架及方法。

背景技术

人工智能在各种技术领域中都有着广泛的应用,其存在的基本问题是其无法像人类一样高效地学习,需要不断地用训练样本对其进行训练学习,训练样本越完善越多,则训练得到的人工智能模型的结果就越好。然而在实际过程中经常会出现训练样本数量不足的问题,因此如何进行有效的少样本学习,已成为人工智能学习领域的一个备受关注的问题。

元学习是解决少样本学习的一种有效方法,元学习也可被理解为“学习如何学习”,现有的元学习方法包括基于度量学习的方法、基于元优化的方法、基于循环模型的方法,但是这些元学习方法的损失函数仅与特定任务有关,而没有域无关或者域自适应的约束,因此,这些方法在单域任务上表现良好,而在跨域数据上都存在着泛化性能不足的缺陷。

具体而言,元测试阶段和元训练阶段中不同类别产生的分布不同,导致了元学习方法存在领域转换的问题,尽管多初始值技术在识别任务模式方面取得了成功,但依旧无法解决由不同类别分布产生的领域转移导致的其在跨域领域存在泛化不足的缺陷。也就是说,现有的领域适应方法只能使元学习方法适应单模态元测试领域,而不能适应多模态元测试领域,到目前为止,如何缓解多模式设置中元训练和元测试阶段之间的领域转换仍然是一个挑战。

总计而言,目前的元学习方法无法很好地适用于跨域数据的学习,也就限制了其在集合跨域数据的应用场景的应用。

发明内容

本发明的目的在于提供一种基于域自适应的多初始值元学习框架及方法,可广泛适用于跨域数据的元学习,且在一定程度上提高算法在少样本问题中的准确率。

为实现上述目的,本技术方案提供一种基于域自适应的多初始值元学习框架及方法,该基于域自适应的多初始化元学习框架包括:

跨域编码器,将输入数据通过共有编码器编码为共有特征向量,通过私有编码器编码为私有特征向量;

跨域调制网络,将共有特征向量编码为域公用调制向量,私有特征向量编码为域专用调制向量;

元分离网络,用于在源域和目标域中更新元学习器,其中元学习器的参数分为由域公用调制向量调制的参数公共部分和由域专用调制向量调制的参数私有部分。

其中跨域编码器的损失函数的计算公式为:

L

其中L

重构误差损失函数L

该公式中,其中

跨域差异损失函数L

其中

跨域相似损失函数L

其中m

其中跨域调制网络将共有特征向量编码为域公用调制向量的公式为:

其中

其中跨域调制网络将私有特征向量编码为域专用调制向量的公式为:

其中

其中元分离网络的更新参数公共部分的公式如下:

其中

元分离网络的更新参数私有部分的公式如下:

其中

具体的,由于神经网络的低层和高层显示出不同类型的信息。低层网络往往具有良好的迁移性,其不针对特定任务,而具有针对不同任务的通用性,而网络从低到高的过程中,其特征也从一般性过渡到特定性。本发明基于这一重要现象,设计跨域编码器、跨域调制网络、元分离网络三个部分。其中元分离网络的参数公共部分为低层网络参数,不同任务的低层网络跨域共享以进行元学习和联合训练,并由域公用调制向量调制;元分离网络的参数私有部分为高层网络参数,特定于单个任务,因此是在不同域分开训练,并由域专用调制向量调制。通过有效利用低层网络的通用性,本发明可以实现提高元学习对跨域数据的泛化性能;通过有效利用高层网络的特定性,且本发明也可以实现提高元学习对跨域数据的预测速度和效率,由于元学习广泛用于解决少样本问题,因此本发明可以应用于各应用领域的少样本问题。

该基于域自适应的多初始化元学习模型的结构如图1所示,不同域的数据 X

第二方面,本方案提供一种基于域自适应的多初始值元学习方法,利用上述基于域自适应的多初始值元学习模型进行学习,包括以下步骤:

初始化:随机初始化元分离网络参数和跨域编码器参数、跨域调制网络参数;

数据采样:从源域和目标域数据中分别采样支持集和查询集数据;

自适应于支持集的元分离网络参数获取:将支持集数据输入跨域编码器,输出共有特征向量和私有特征向量;将共有特征向量和私有特征向量输入跨域调制网络,输出域公用调制向量和域专用调制向量;将域公用调制向量调制得到元分离网络参数公共部分,将域专用调制向量调制成元分离网络参数私有部分,得到调制后的元分离网络;将支持集数据输入调制后的元分离网络,计算第一网络梯度;将第一网络梯度用于更新元分离网络参数,得到自适应于支持集的元分离网络参数,遍历源域和目标域的支持集数据。

更新网络参数:根据查询集在自适应于支持集的元分离网络上的误差,计算第二梯度,并更新元分离网络参数公共部分,元分离网络参数私有部分,跨域编码器,跨域调制网络后回归进行自适应于支持集的元分离网络参数获取步骤,直到网络收敛,输出所有网络参数。

其中计算第一网络梯度的公式如下:

其中θ为初始化元分离网络参数,T

其中更新元分离网络参数公共部分:

更新元分离网络参数私有部分:

更新跨域编码器

更新跨域调制网络:

相较现有技术,本技术方案的有益效果和特点如下:首先,提出了一种新的单分散网络结构的基于域自适应的多初始值元学习方法来提高元学习在多模态任务上的性能,针对多初始化域移位问题,提出了一种基于薄膜网络的跨域调制网络,将任务编码为域公共和域私有调制向量。基于生成的调制向量,一种新的元分离网络(MSN)提出了在源和目标域更新元学习器,元学习器的参数分为由域公用调制向量调制的参数公共部分和由域专用调制向量调制的参数私有部分。此外,元学习中共享的公共参数是由前几层学习的,而私有参数是由后几层学习的。关键的原因是较低的层次可能生成通用的特征,而较高的层次可能学习特定的特征。此外,将不等式测度纳入元学习的更新过程中,以进一步提高元学习的泛化能力。

附图说明

图1是根据本发明的一种基于域自适应的多初始值元学习模型框架示意图。

图2是基于域自适应的多初始值元学习方法的伪代码图。

图3是强化学习实验用的实验图。

具体实施方式

下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员所获得的所有其他实施例,都属于本发明保护的范围。

实施例:

本申请人采用了大量的实验,包括回归、图像分类和强化学习(RL),以评估本方案提出的方法在各种多模态少样本学习任务中的应用,为了进行比较,本申请人考虑了以下元学习方法作为参考:

1.MAML:MAML是传统的模型无关元学习算法的代表,它们已经应用于广泛的研究领域;

2.Multi-MAML:由多个MAML模型组成,每个模型都是根据从单个模态中采样的任务进行专门训练的。值得注意的是,Multi-MAML是在单模态上进行评估,这意味着该方法无须像其他方法一样进行模态的判别,而这在实际应用中是无法实现的:无法提前知道即将到来的数据是什么模态。因此其性能是MAML算法在准确辨识模态情况下的性能上界,而且其性能在实际中无法得到。

3.MMAML:作为最近的一项研究成果,MMAML应用特征线性调制(FiLM)来识别任务的模式,然后调整元学习器参数以产生多个初始化。

本方案的方法用MIML-DA表示。

回归实验:

实验条件准备:进行multimodal few-shot regression实验,本申请人从一维函数中抽取五对输入输出数据{x

实验方法:

首先考虑了三种基线方法MAML,Multi-MAML和MMAML,这三种方法都有元网络,MMAML进一步用一个调制网络来增强元网络。首先,将按x值排序的数据点输入调制网络,生成特定任务的调制向量,用于调制元分离网络,然后,进一步调整调制后的元分离网络。

实验结果:如表一所示,表一说明了本方案的方法和其他基线方法在平均均方误差(MSE)方面的性能,其中每种情况下的最低值用黑体突出显示。结果表明,所提出的MIML-DA在所有情况下都达到了最佳性能。更具体地说,传统的 MAML在所有情况下都有最大的误差,并且加入任务身份的Multi-MAML的性能明显优于MAML,这表明在多模态任务分布下MAML会退化。由于调制网络产生的向量暗示了输入数据的模式,调制后的元学习器可以得到更好的初始化。因此,基于梯度的优化方法在这种情况下可以获得更好的性能,因此MIML-DA和MMAML的性能明显优于LSTM学习者。最后,MIML-DA的性能优于MMAML,因为跨域调制网络和元分离网络减少了训练和测试阶段之间的域偏移,从而提高了泛化能力。

表一

图像分类实验:

进行multi-modal few-shot image classification实验,这种分类任务考虑了将图像分类为N类,其中标记可用样本数为K的N类,称为N-way-shot分类;创建类似于Triantafliou等人的多模式任务,将多个广泛使用的数据集组合在一起,形成由OmniglotLake等人组成的元数据集。本申请人在元数据集上训练模型,包括两种模式(Omniglot和Mini Imagenet)、三种模式(Omniglot、 Mini Imagenet和FC100)和五种模式(所有五个数据集)。

总体结果见表二。观察到本方案提出的MIML-DA方法在几乎所有情况下都达到了最佳性能,只有一个值除外。总的来说,分类方法之间的性能比较类似于回归方法。随着模式数目的增加,MIML-DA与基线之间的性能差距越来越大,表明我们的方法能够更好地处理多模式任务分布。值得注意的是,Multi-MAML获得了很好的性能,因为每个Multi-MAML很可能会过度适应一个具有较少类的单个数据集。相反,MMAML和MIML-DA从所有数据集中学习模型。结果表明,由于调制网络的特性,MMML-DA的性能略好于MMAML和MMAML,由于跨域调制网络和MSN 的特性,MIML-DA的性能要好于MMAML和MMAML。

表二

强化学习实验:

在MuJoCo物理模拟器上验证MIML-DA在多模态元强化学习中的能力,以适应基于有限经验的新任务。考虑到图3中的三个环境,在每个时间点上对agent 进行奖励,以最小化从多模态分布中采样的到未知目标的距离。

用ProMP代替MAML作为我们的基准,此外,基线Multi-ProMP使用Vuorio 等人(2019年)提出的ProMP方法为每种模式训练一个策略。由于任务的对称分布和随机初始值,agent只接受一种模式的训练移动。同样利用了ProMP对 MMAML和MIML-DA的策略和调制网络进行了优化。

结果如表三、表四和表五所示。如所观察到的,在所有三种环境中,MIML-DA 在各种模式下的表现始终优于ProMP和MMAML。值得注意的是,由于每个多ProMP 只考虑单一模式,所以Multi-ProMP表现出良好的性能。

表三

表四

表五

上述具体实施方式,并不构成对本发明保护范围的限制。本领域技术人员应该明白的是,取决于设计要求和其他因素,可以发生各种各样的修改、组合、子组合和替代。任何在本发明的精神和原则之内所作的修改、等同替换和改进等,均应包含在本发明保护范围之内。

去获取专利,查看全文>

相似文献

  • 专利
  • 中文文献
  • 外文文献
获取专利

客服邮箱:kefu@zhangqiaokeyan.com

京公网安备:11010802029741号 ICP备案号:京ICP备15016152号-6 六维联合信息科技 (北京) 有限公司©版权所有
  • 客服微信

  • 服务号