首页> 中国专利> 基于记忆巩固机制与GAN模型的序列化任务完成方法及系统

基于记忆巩固机制与GAN模型的序列化任务完成方法及系统

摘要

本发明涉及基于记忆巩固机制与GAN模型的序列化任务完成方法,与现有技术相比解决了多任务场景下模型会出现遗忘致使无法完成序列化任务的缺陷。本发明包括以下步骤:序列化任务的获取;设定索引器并生成任务索引;利用GAN模型进行任务训练;进行伪样本的联合训练;新序列化任务的完成。本发明通过重要参数的保护和记忆回放的设计,将记忆巩固机制应用于GAN模型,使得GAN模型具备了多任务处理能力,能够保留子任务中的重要信息、遗忘非重要信息,实现了序列化任务的完成。

著录项

说明书

技术领域

本发明涉及序列化任务处理方法技术领域,具体来说是基于记忆巩固机制与GAN模型的序列化任务完成方法及系统。

背景技术

生成式对抗网络(GAN,Generative Adversarial Networks)是一种深度学习模型,其由一个生成器和一个判别器构成,通过对抗学习的方式来训练,目的是估测数据样本的潜在分布并生成新的数据样本。生成器的目的是尽量去学习真实的数据分布,而判别器的目的是尽量正确判别输入数据是来自真实数据还是来自生成器;为了取得游戏胜利,这两个游戏参与者需要不断优化,各自提高自己的生成能力和判别能力,这个学习优化过程就是寻找二者之间的一个纳什均衡。

目前,GAN广泛应用于图像和视觉、语音和语言等领域,如图像生成,图像去噪,风格转移,高分辨率用重建,语音合成等。虽然GAN的应用场景丰富,但是在现实环境中,其面临着一个严重问题—灾难性遗忘。

灾难性遗忘指人工神经网络(Artificial neural networks,ANN)学习了新的知识之后,几乎彻底遗忘掉之前习得的内容。在现实世界中,很多任务不可能一次性得到所有的训练数据,例如开放的环境、非特定的任务。这就使得ANN必须能够利用不断产生的新数据持续地,增量地学习新知识,并且不遗忘之前所学过的重要内容,以致于GAN模型无法应用于序列化任务的应用。

序列化任务的最大特点是多任务的交错执行,不同的任务随着不同的请求进行不同的执行,通常任务一完成后,跟着的是并不相同的任务二、任务三、任务四。而由于GAN模型的灾难性遗忘,使得其在学习完任务二后不再记得任务一的内容,致使序列化任务无法完成。

因此,如何设计一种能够实现序列化、多任务的完成方法已经成为急需解决的技术问题。

发明内容

本发明的目的是为了解决现有技术中多任务场景下模型会出现遗忘致使无法完成序列化任务的缺陷,提供一种基于记忆巩固机制与GAN模型的序列化任务完成方法及系统来解决上述问题。

为了实现上述目的,本发明的技术方案如下:

一种基于记忆巩固机制与GAN模型的序列化任务完成方法,包括以下步骤:

11)序列化任务的获取:获得待进行处理的序列化任务;

12)设定索引器并生成任务索引:对索引器进行设定,将序列化任务输入索引器生成任务索引;

13)利用GAN模型进行任务训练:使用GAN模型对建立任务索引号后的任务进行训练;

14)进行伪样本的联合训练:利用记忆回放方式生成任务一的伪样本,利用GAN模型对建立任务索引号后的任务二和任务一的伪样本进行训练,训练中保护任务一的参数重要性;

15)新序列化任务的完成:获取新的序列化任务,重复上述11)-14)步骤,完成新的序列化任务。

所述的对索引器进行设定中索引器为One-hot向量,其形式为:任务1(0,…,1),任务2(0,…,1,0),任务n(0,..1,..0)。

所述的利用GAN模型进行任务训练包括以下步骤:

31)将序列化任务输入GAN模型;

32)利用GAN模型的生成器对建立任务索引号后的任务一进行训练,对任务一进行重要参数的保护,GAN模型的生成器生成任务一的伪样本;

33)利用GAN模型的生成器对建立任务索引号后的任务二进行训练,对任务二进行重要参数的保护,GAN模型的生成器生成任务二的伪样本;

34)利用GAN模型的生成器对建立任务索引号后的任务n进行训练,对任务n进行重要参数的保护,GAN模型的生成器生成任务n的伪样本。

所述的进行重要参数的保护为计算GAN模型的生成器中每个参数的重要性,其包括以下步骤:

41)设定GAN模型的生成器中每个参数的衡量方式采用EWC方式或MAS方式,其表达式如下:

其中,I

I

42)若计算I

I=max(0,I

其中,I为最终的参数重要性,max指取最大值;

43)基于学习新任务时保护重要性大的参数防止其被新任务完全覆盖掉、其允许重要性小的参数的更新以便继续学习新任务的原则,将保护重要参数的方式设定为以下表达式:

其中,L为原有新任务的目标函数,L

所述的进行伪样本的联合训练包括以下步骤:

51)利用GAN模型的生成器生成的伪样本及其原样本索引,送入GAN模型,其表达式如下:

其中,S

52)将旧任务的伪样本与新任务的样本混合,其表达式如下:

S(m+1)={S

其中,S

53)用混合样本对持续学习GAN模型进行联合训练,直到新旧任务的性能都满足设定要求。

所述的序列化任务为图像连续生成任务,其具体包括以下步骤:

61)获取图像连接生成任务,并将图像连接生成任务利用索引器生成任务索引;

62)将图像连接生成任务输入GAN模型进行训练,GAN模型中的重要保护模块计算图像连接生成任务一的参数重要性,并存储下来;

63)GAN模型中的重要保护模块计算图像连接生成任务二的参数重要性,并存储下来;

64)利用记忆回放方式生成图像连接生成任务一的伪样本,GAN模型对建立任务索引号后的任务二和任务一的伪样本进行联合训练;

65)获取待处理的图像连续生成任务,利用61)至64)步骤进行新的图像连接生成。

所述的序列化任务为连续3D打印任务,其具体包括以下步骤:

71)获取连续3D打印任务,其中任务一为3D打印椅子、任务二为3D打印花瓶;

72)利用索引器对任务一的椅子图片作为训练图片建立索引,同时为3D打印花瓶作为图片建立索引;

73)将任务一输入GAN模型进行训练,GAN模型中的重要保护模块计算生成椅子的参数重要性,并存储下来;

74)利用记忆回放方式生成任务一的椅子图片的伪样本,GAN模型对3D打印花瓶的任务二和任务一的伪样本进行联合训练;

75)获取待处理的连续3D打印任务,根据获得信息是椅子或花瓶,进行椅子或花瓶的打印。

所述的序列化任务为连续音色转换任务,其包括以下步骤:

81)获取连续音色转换任务,其中任务一为女声源声、任务二为男声源声;

82)利用索引器对任务一的女声源声建立索引,同时为任务二的男声源声建立索引;

83)将任务一的女声源声输入GAN模型进行训练,GAN模型中的重要保护模块计算生成女声源声的参数重要性,并存储下来;

84)利用记忆回放方式生成任务一的女声源声的伪样本,GAN模型对任务二的男声源声和任务一的伪样本进行联合训练;

85)获取待处理的连续音色转换任务,根据源声将其转换为目标声1或目标声2。

所述的序列化任务为智能小车工作任务,其包括以下步骤:

91)获取智能小车工作任务,其中任务一为识别分拣工作一、任务二为识别分拣工作二;

92)利用索引器对任务一的识别分拣工作一建立索引,同时为任务二的识别分拣工作二建立索引;

93)将任务一的识别分拣工作一输入GAN模型进行训练,GAN模型中的重要保护模块计算生成识别分拣工作一的参数重要性,并存储下来;

94)利用记忆回放方式生成任务一的识别分拣工作一的伪样本,GAN模型对任务二的识别分拣工作二和任务一的伪样本进行联合训练;

95)获取待处理的智能小车工作任务,根据其任务将其执行识别分拣工作一或识别分拣工作二。

基于记忆巩固机制与GAN模型的序列化任务完成方法的系统,包括序列化任务输入模块、索引器、GAN模型和序列化任务输出模块,所述的GAN模型包括生成器和判别器,所述的生成器中包括记忆回放模块和重要参数保护模块,重要参数保护模块用于保护重要参数防止其在后续学习中补覆盖,记忆回放模块通过混合新任务和部分旧任务样本对GAN持续学习系统进行联合训练,从而巩固旧任务,索引器用于随时间次序为任务创建索引号;所述的序列化任务输入模块的输出端与索引器的输入端相连,索引器的输出端与GAN模型的输入端相连,GAN模型的输出端与序列化任务输出模块的输入端连。

有益效果

本发明的基于记忆巩固机制与GAN模型的序列化任务完成方法及系统,与现有技术相比通过重要参数的保护和记忆回放的设计,将记忆巩固机制应用于GAN模型,使得GAN模型具备了多任务处理能力,能够保留子任务中的重要信息、遗忘非重要信息,实现了序列化任务的完成。同时,结合记忆巩固机制后的GAN模型更加贴合人类大脑记忆模型,更高效地解决复杂地连续学习问题。

附图说明

图1为本发明的方法顺序图;

图2为本发明第一种实施方式中普通GAN和本发明所述GAN网络在连续生成MNIST数字手写体的效果对比图;

图3为本发明第二种实施方式中普通GAN和本发明所述GAN网络连续3D打印的效果对比图;

图4为本发明第三种实施方式中基于本发明所述GAN网络的多音色变声软件界面图;

图5为图4所述多音色变声软件工作时音色转换波谱图。

具体实施方式

为使对本发明的结构特征及所达成的功效有更进一步的了解与认识,用以较佳的实施例及附图配合详细的说明,说明如下:

在实际研究中发现,造成ANN灾难性遗忘的根本原因是新数据会修改与历史知识相关的重要神经元的参数。现有的灾难性遗忘问题解决方法大致可以分为两派:参数派与结构派。其中,参数派主张对深度学习进行修补,结构派主张提出新的人工神经网络模型。大致有四种方法:1)利用新数据训练的同时,不断用包含历史数据相关的信息刺激神经元,形成一种竞争,从而使历史知识相关的重要神经元的参数尽可能少的受影响,同时也保证了新知识能够被学习;2)在开始训练新数据前,利用旧网络对新数据进行预测得到虚拟的训练数据,可以看作是旧网络的一个回忆,目标函数中包含新旧网络的参数约束,每训练一个新数据,利用所有的虚拟数据约束旧参数,抑制遗忘;3)从另一个角度来约束参数的变化,不同的任务对应不同参数的概率分布,如果能找到两个分布重叠的部分,并将参数约束到这个区域,那么这一参数不就可以对这些任务都有效;4)保留所有的历史数据,并用其反复训练网络来防止遗忘。目前上述方法均在一定程度上缓解了灾难性遗忘,但是效果均有限,究其原因是上述机制并没有较高程度地复现人类记忆机制并应用于ANN上。

本发明在原始GAN模型的基础上,增加了索引器,重要参数保护模块和记忆重放模块。当序列化的任务到来时,这里以两个任务为例,任务1:生成猫的图像,任务2:生成狗的图像。索引器为每一个任务建立唯一的索引。接下来,GAN持续学习系统学习任务1。学习完成后,重要参数保护模块开始工作,其由两个步骤组成。首先当学习完成任务1后,其计算生成器关于任务1的参数重要性,这里参数重要性指神经网络参数包含当前任务的重要信息量,重要性大的参数包含了较多任务的重要信息。接下来在学习任务2时,重要参数保护模块适当抑制具有较高重要性的参数的变化以达到保护重要信息的目的。同时,在学习任务2时记忆重放模块也开始工作,其首先再次生成部分任务1的样本数据,并将该数据与任务2的数据混合,最后对GAN持续学习系统进行联合训练,直到其在任务1和任务2上均满足所需性能。可以看出,重要参数保护模块通过模仿突触可塑性机制对包含重要信息的参数进行保护,以防止这些参数在后续学习新任务的阶段中被代替。记忆回放模块模仿海马区记忆巩固机制,通过用部分旧任务的反复刺激神经网络防止其被遗忘。

本发明所述的基于记忆巩固机制与GAN模型的序列化任务完成方法的系统,包括序列化任务输入模块、索引器、GAN模型和序列化任务输出模块,所述的GAN模型包括生成器和判别器,所述的生成器中包括记忆回放模块和重要参数保护模块,重要参数保护模块用于保护重要参数防止其在后续学习中补覆盖,记忆回放模块通过混合新任务和部分旧任务样本对GAN持续学习系统进行联合训练,从而巩固旧任务,索引器用于随时间次序为任务创建索引号;所述的序列化任务输入模块的输出端与索引器的输入端相连,索引器的输出端与GAN模型的输入端相连,GAN模型的输出端与序列化任务输出模块的输入端连。

其中,索引器为随时间依次到来的任务创建索引号,每个任务的索引号具有唯一性。索引器的产生主要基于以下三个原因。首先,在现实场景中任务的到来是具有时序性的,因此任务对应的索引号可以用来区分不同时段的任务。其次,任务具有时效性,即过去任务的数据往往不可再次得到,而新任务的数据很快会覆盖掉过去的数据,但是任务的索引号是可以保留的,因此索引器的另一个功能就是保留过去任务的身份信息。最后,当过去任务需要被再次执行时,可以将任务的索引号送入GAN持续学习系统,由于GAN持续学习系统已经具备持续学习的能力,所以当再次遇到过去某个任务的索引号时,便会激发系统关于该任务的记忆,从而达到实现复现该任务的目的。索引的形式多种多样,只要能区分不同任务即可。最简单的形式是One-hot向量,如任务1(0,…,1),任务2(0,…,1,0),任务n(0,..1,..0)。

重要参数保护模块是模拟人类突触可塑性机制创建的。人类大脑海马区中突触的强度在接受新信息时一部分会被增强,一部分会被减弱。被增强的突触可以强化信息的存储,被减弱的突触可以为后续信息的学习提供空间。重要参数保护模块模拟突触可塑性计算神经网络中参数的重要性,并保护重要参数防止其在后续学习中被覆盖。具体工作流程如下:

步骤1:当学习完一个任务后,其首先计算GAN的生成器网络中每个参数的重要性,参数重要性的衡量方式采用弹性权重巩固算法(Elastic Weight Consolidation,EWC)或记忆感知突触算法(Memory Aware Synapses,MAS),分别如公式(1)和(2)所示:

其中,I

注意,当I

I=max(0,I

其中,I为最终的参数重要性,max指取最大值。

步骤2:在得到参数重要性后,重要参数保护模块需要在学习新任务时保护重要性大的参数防止其被新任务完全覆盖掉,同时其允许重要性小的参数的更新以便继续学习新任务。其保护重要参数的方式如公式(4),在原有新任务的目标函数的基础上,根据参数的重要性,对重要参数的值的改变施加一定的惩罚。

其中,L为原有新任务的目标函数,L

记忆回放模块模拟海马区经验回放机制。海马区经验回放是指,人类在睡眠期间,海马区会回放过的的知识,激活相应的突触来增强记忆。记忆回放模块通过混合新任务和部分旧任务的样本对GAN持续学习系统进行联合训练,从而巩固旧任务。其具体流程:

步骤1:由于旧任务的真实数据已经不可得,首先利用生成器生成部分旧任务的样本(伪样本),即将过去任务的索引送入GAN持续学习系统,如公式(5)所示:

其中,S

步骤2:将旧任务的伪样本与新任务的样本混合,如公式(6),之后用混合样本对持续学习GAN系统进行联合训练,直到新旧任务的性能都满足要求,如公式(6)所示:

S(m+1)={S

其中,S

如图1所示,本发明所述的一种基于记忆巩固机制与GAN模型的序列化任务完成方法,包括以下步骤:

第一步,序列化任务的获取:获得待进行处理的序列化任务。

第二步,设定索引器并生成任务索引:对索引器进行设定,将序列化任务输入索引器生成任务索引。对索引器进行设定中索引器为One-hot向量,其形式为:任务1(0,…,1),任务2(0,…,1,0),任务n(0,..1,..0)。

第三步,利用GAN模型进行任务训练:使用GAN模型对建立任务索引号后的任务进行训练。其具体步骤如下:

(1)将序列化任务输入GAN模型。

(2)利用GAN模型的生成器对建立任务索引号后的任务一进行训练,对任务一进行重要参数的保护,GAN模型的生成器生成任务一的伪样本。

(3)利用GAN模型的生成器对建立任务索引号后的任务二进行训练,对任务二进行重要参数的保护,GAN模型的生成器生成任务二的伪样本。

(4)利用GAN模型的生成器对建立任务索引号后的任务n进行训练,对任务n进行重要参数的保护,GAN模型的生成器生成任务n的伪样本。

以上进行重要参数的保护为计算GAN模型的生成器中每个参数的重要性,其包括以下步骤:

首先,设定GAN模型的生成器中每个参数的衡量方式采用EWC方式或MAS方式,其表达式如下:

其中,I

I

其次,若计算I

I=max(0,I

其中,I为最终的参数重要性,max指取最大值;

最后,基于学习新任务时保护重要性大的参数防止其被新任务完全覆盖掉、其允许重要性小的参数的更新以便继续学习新任务的原则,将保护重要参数的方式设定为以下表达式:

其中,L为原有新任务的目标函数,L

第四步,进行伪样本的联合训练:利用记忆回放方式生成任务一的伪样本,利用GAN模型对建立任务索引号后的任务二和任务一的伪样本进行训练,训练中保护任务一的参数重要性。其具体步骤如下:

(1)利用GAN模型的生成器生成的伪样本及其原样本索引,送入GAN模型,其表达式如下:

其中,S

(2)将旧任务的伪样本与新任务的样本混合,其表达式如下:

S(m+1)={S

其中,S

(3)用混合样本对持续学习GAN模型进行联合训练,直到新旧任务的性能都满足设定要求。

第五步,新序列化任务的完成:获取新的序列化任务,重复上述第一至第五步骤,完成新的序列化任务。

在此,作为本发明的第一种实施方式,其为图像连续生成任务时,神经网络在训练过程中往往面临数据量不足的情况,数据量不足会导致网络无法充分学习到数据的特征,造成模型过拟合问题。解决数据量不足的一个方法是扩充数据集,其中用GAN生成数据集是一种有效而简便的方法。我们将GAN持续学习系统应用于图像生成,检测其是否具有连续生成多类别图像的能力。实验中,采用MNIST数据集,MNIST数据集包含手写体数字0,1,2,3,4,5,6,7,8和9。我们将其分为两组,第一组是0-4,第二组是5-9。任务1是生成第一组数字,任务2是生成第二组数字。具体步骤如下:

步骤1:索引器为第一组数字0-4的训练图片创建索引号;

步骤2:将这些图片送入持续学习GAN系统进行生成训练;

步骤3:重要参数保护模块计算生成第一组数字0-4的参数重要性,并存储下来;

步骤4:索引器为第二组数字5-9的训练图片创建索引号;

步骤5:记忆回放模块生成部分第一组数字0-4的伪样本;

步骤6:记忆回放模块混合第一组数字0-4的伪样本和第二组数字5-9的样本;

步骤7:用混合的样本训练持续学习GAN系统,训练过程保护步骤三得到的重要参数。

结果如图2所示,可以看出,普通的GAN发生了灾难性遗忘,即在学习完第二个任务后忘记了如何生成第一组数字。而持续学习GAN系统可以在学习任务2后,依旧保留生成任务1的能力。

作为本发明的第二种实施方式,序列化任务为连续3D打印任务。3D打印即快速成型技术的一种,又称增材制造,它是一种以数字模型文件为基础,运用粉末状金属或塑料等可粘合材料,通过逐层打印的方式来构造物体的技术。我们将持续学习GAN模型应用于3D打印,检测其是否具备连续3D打印的能力。任务1是3D打印椅子,任务2是3D打印花瓶。具体步骤如下:

步骤1:索引器为椅子的训练图片创建索引号;

步骤2:将这些图片送入持续学习GAN系统进行生成训练;

步骤3:重要参数保护模块计算生成椅子的参数重要性,并存储下来;

步骤4:索引器为花瓶的训练图片创建索引号;

步骤5:记忆回放模块生成部分椅子的伪样本;

步骤6:记忆回放模块混合椅子的伪样本和花瓶的样本;

步骤7:用混合的样本训练持续学习GAN系统,训练过程保护步骤三得到的重要参数。

实验结果如图3所示,普通GAN网络发生了灾难性遗忘,其在学会3D打印花瓶后,就忘记了如何3D打印椅子。而持续学习GAN系统可以在学习3D打印椅子的基础上,增量地学习3D打印花瓶。

作为本发明的第三种实施方式,序列化任务为连续音色转换任务,语音转换是一种在保留语言信息的同时转换指定话语的语言信息的技术。语音转换可以在很多地方得到应用,如文本到语音系统的说话人身份(男女,老少等)的转换等。目前不具备记忆功能的语音转换系统在学习一个音色转换后,往往会覆盖掉之前的音色转换信息,这样语音转换系统就无法增量地学习。我们将持续学习GAN系统应用到现在的语音转换系统,起名为多音色变换软件平台,如图4和图5所示。在多音色变换软件平台中,一种女(男)声可以连续地转换为多种不同类型的男(女)声。具体步骤如下:

步骤1:索引器为女声(源声)和第一种男声(目标声1)创建索引号;

步骤2:将源声和目标声1送入持续学习GAN系统进行音色转换训练;

步骤3:重要参数保护模块计算原声到目标声1的参数重要性,并存储下来;

步骤4:索引器为第二种男声(目标声2)创建索引号;

步骤5:记忆回放模块生成源声和目标声1;

步骤6:记忆回放模块混合生成源声和目标声1和目标声2;

步骤7:用混合的样本训练持续学习GAN系统学会将源声转换为目标声1和目标声2,训练过程保护步骤三得到的重要参数。

多音色变换软件平台实现了增量学习多种不同类型的音色转换。

作为本发明的第四种实施方式,序列化任务为智能小车工作任务,我们将持续学习GAN系统应用于物流小车的识别和分拣工作。不具备持续学习能力的物流小车无法在保持在当前物体的识别和分拣的基础上,随着种类增多的物体进行增量识别和分拣,极大地限制了其在现实场景中的应用。基于GAN持续学习系统的智能物流小车,可以增量地学习新类别的物体的识别和分拣工作。具体步骤如下:

步骤1:索引器为物流小车的识别和分拣工作1创建索引号;

步骤2:将工作1的训练数据送入持续学习GAN系统进行训练;

步骤3:重要参数保护模块计算工作1的参数重要性,并存储下来;

步骤4:索引器为物流小车的识别和分拣工作2创建索引号;

步骤5:记忆回放模块生成部分工作1的伪样本;

步骤6:记忆回放模块混合工作1的伪样本和工作2的样本;

步骤7:用混合的样本训练持续学习GAN系统,训练过程保护步骤3得到的重要参数。

以上显示和描述了本发明的基本原理、主要特征和本发明的优点。本行业的技术人员应该了解,本发明不受上述实施例的限制,上述实施例和说明书中描述的只是本发明的原理,在不脱离本发明精神和范围的前提下本发明还会有各种变化和改进,这些变化和改进都落入要求保护的本发明的范围内。本发明要求的保护范围由所附的权利要求书及其等同物界定。

去获取专利,查看全文>

相似文献

  • 专利
  • 中文文献
  • 外文文献
获取专利

客服邮箱:kefu@zhangqiaokeyan.com

京公网安备:11010802029741号 ICP备案号:京ICP备15016152号-6 六维联合信息科技 (北京) 有限公司©版权所有
  • 客服微信

  • 服务号