技术领域
本发明涉及信息技术领域,特别是涉及一种基于神经网络架构搜索的联邦学习方法及系统。
背景技术
如今的集中式网络训练已经能够执行各种数据挖掘任务,但充足的数据集是保证模型可靠性的前提。考虑到隐私保护、通信和存储成本、知识产权保护、监管限制和法律约束等问题,在一些特殊的领域,比如:金融、医疗和政务等,将各方数据收集在一个服务器上进行建模和挖掘十分困难。联邦学习是解决此类问题的一个通用方法,它是近年来提出的一种分布式机器学习范式。可以在各个参与方不透露自己数据信息的前提下,通过同态加密等加密算法,将各个参与方的梯度结果汇总到一个服务器上来更新训练一个全局模型。
在每一轮的训练过程中,各个参与方用自己本地的数据来训练此时的全局模型,并且得到一个梯度信息。通过同态加密等方式对梯度信息进行加密后传输给服务器端。服务器端根据相应的聚合算法(如FedAvg算法)将各个梯度聚合并更新全局参数完成本轮训练。
在现有的联邦学习模型中,人们使用预设的模型作为初始的全局模型进行训练。这种使用预先设定的模型进行联邦训练的方式会降低模型在实际数据分布场景下的鲁棒性降低,因为多数情况下,数据集在参与方中都是非独立同分布的,而且开发人员并不知道数据分布的具体情况。当然,开发人员可以选择不同的模型结构进行尝试,来的得出最合适的模型以增强其鲁棒性,但是这样做会耗费更高的成本,加重计算负担,存在一定的缺陷。
发明内容
为解决上述技术问题,本发明的目的在于,提供一种基于神经网络架构搜索的联邦学习方法及系统,实现联邦学习计算的简化,并提升联邦学习的鲁棒性以及降低模型训练的成本。
为达到上述目的,本发明采用如下技术方案:
本发明提供一种基于神经网络架构搜索的联邦学习方法,通过迭代方法直至满足设定的停止条件,包括如下步骤:
S100、中央服务器将搜索空间中神经网络的权重w和架构α发送给K个参与方设备,并且发送一个随机生成的私有密钥
S200、各个参与方设备基于所获得的w和α以及本地数据集进行训练,得到训练集和验证集的损失值以及w和α的梯度,并根据梯度下降的方法更新网络的架构和权重;
S300、每个参与方设备将各自w和α进行加密,加密所用的密钥包含中央服务器发送的初始密码以及所有之前参与方的w和α结果,并通过链式的方法传递给下一个参与方,作为下一个参与方加密密钥;
S400、最后一个参与方设备将本轮通信中所有参与方设备的加密w和α发送给中央服务器后,中央服务器用自己的密钥进行解密得出聚合的w和α结果,且基于聚合结果更新全局神经网络w和α,并将新的w和α值发送给各个参与方设备继续更新。
进一步的,在步骤S100中,所述搜索空间为连续可微的,且所述搜索空间通过梯度下降的混合运算搜索空间作为初始全局模型。
进一步的,在步骤S200中,本地数据集包括训练集、测试集以及验证集。
进一步的,在步骤S200中,网络权重w以及架构α的计算公式为:
其中
进一步的,在步骤S300中,参与方设备在通信过程中,中央服务器将所有参与方设备都发送w和α,中央服务器生成的密钥指向链式中第一个参与方设备发送,且通过最后一个参与方设备发送数据给中央服务器。
进一步的,在步骤S400中,最后一个参与方设备发送的数据为使用了
其中
进一步的,在步骤S400中,中央服务器通过
w
一种基于神经网络架构搜索的联邦学习系统,包括中央服务器以及若干个参与方设备,且所述参与方设备保存有本地数据集。
相比于现有技术,本发明具有以下优点:
1)根据各个参与方的数据异质性(数据非独立同分布)的具体特点搜索到最适合数据分布的神经网络作为训练的初始全局模型,保证了模型训练的精度,增强了联邦学习框架的鲁棒性。
2)用分布式的思想来进行神经网络搜索,用DARTS连续可微的搜索空间进行梯度下降来完成这个目标,充分利用了各个参与方节点的计算资源,提高了最优神经网络获取的速度。
3)通过链式加密的方式,在协同训练的过程中,任意参与方设备都无法获知其他参与方的信息值,保证了所有参与方的隐私数据,同时,中央服务器最终基于自己的密钥可以解密获取到精确的w和α,保证了模型训练的精度。
附图说明
图1为本发明一种基于神经网络架构搜索的联邦学习方法一种实施例的流程图;
图2为本发明一种基于神经网络架构搜索的联邦学习方法中链式传输加密以及搜索空间梯度下降更新的示意图;
图3为本发明一种基于神经网络架构搜索的联邦学习方法中的搜索空间以及最终的单元结构搜索结果示意图。
具体实施方式
下面将结合示意图对本发明的一种基于神经网络架构搜索的联邦学习方法及系统进行更详细的描述,其中表示了本发明的优选实施例,应该理解本领域技术人员可以修改在此描述的本发明,而仍然实现本发明的有利效果。因此,下列描述应当被理解为对于本领域技术人员的广泛知道,而并不作为对本发明的限制。
如图1所示,本发明实施例提出了一种基于神经网络架构搜索的联邦学习方法,具体地对每一步骤进行详细说明。
步骤S100:中央服务器将搜索空间中神经网络的权重w和架构α发送给K个参与方设备,并且发送一个随机生成的私有密钥
具体的,中央服务器发送给各个参与方w和α,为本轮搜索的权重和架构,初始的搜索空间定义遵循DARTS的定义,本发明在两个共享的卷积单元中进行搜索,然后将其构建为一个完整的模型架构。在单元内部,将两个节点之间的分类候选操作(例如卷积、最大池化、全连接)放宽到一个连续的搜索空间。
同时,参考图2,中央服务器是神经网络构架搜索的关键,整个迭代中单元的搜索过程可以抽象为三个步骤:
(1)首先定义搜索空间:对单元中各个模块边的操作最初是未知的,通过在每一条边上混合放置候选操作来连续松弛搜索空间,在图中使用多条边来表示多种操作;
(2)每一个参与方通过求解双层优化问题,实现混合概率和网络权值的联合优化;
(3)从所学的混合概率中归纳出最终的结构。
另外,参考图3,搜索空间为连续可微的,且所述搜索空间通过梯度下降的混合运算搜索空间作为初始全局模型。根据DARTS,定义允许使用梯度下降的混合运算搜索空间,在完成了一系列本地搜索和聚合之后,搜索出的单元结构。
混合运算操作包含以下操作:3×3和5×5可分卷积、3×3和5×5可分扩张卷积、3×3最大池化、3×3平均池化等。
在本例中最后搜索得出的卷积单元由7个节点组成,其中输出节点定义为所有中间节点的纵深级联,通过将多个单元堆叠在一起形成网络。将k单元的第一和第二节点分别设置为k-2和k-1单元的输出,并且根据需要插入1×1卷积。
步骤S200,各个参与方设备基于所获得的w和α以及本地数据集进行训练,得到训练集和验证集的损失值以及w和α的梯度,并根据梯度下降的方法更新网络的架构和权重。
具体的,各个参与方设备基于所获得网络和本地数据集进行训练,得到训练集和验证集的损失值以及w和α的梯度值,包括
其中
值得注意的是,由于本发明所使用的搜索空间是可微的,所以搜索神经网络架构可以通过梯度下降的方式。
步骤S300中,每个参与方设备将各自w和α进行加密,加密所用的密钥包含中央服务器发送的初始密码以及所有之前参与方的w和α结果,并通过链式的方法传递给下一个参与方,作为下一个参与方加密密钥。
具体的,第一个参与方设备在经过本地搜索得到新的w和α后用
由于之后的参与方设备操作相似,这里就用第i个参与方设备作为一个例子进行说明:
第i个参与方设备收到第i-1个参与方设备发送
最后一个设备k按照上述方法更新完加密的w和α中后,发送给中央服务器,此时发送的数据为使用了
其中
此步骤中的加密操作有两个有益之处:一是保证了任何一个参与方的隐私数据不会被泄露;二是第i个参与方所用的密钥就包含了前i-1个参与方的搜索结果,这样只需最后一个参与方发送最终结果给中央服务器,就能完成本轮聚合,而无需每一个参与方设备都向中央服务器发送数据。
步骤S400:最后一个参与方设备将本轮通信中所有参与方设备的加密w和α发送给中央服务器后,中央服务器用自己的密钥进行解密得出聚合的w和α结果,且基于聚合结果更新全局神经网络w和α,并将新的w和α值发送给各个参与方设备继续更新。
具体的,通过最先生成的密钥
w
重复步骤S400直至达到迭代结束条件。
如图2所示,本发明还提供一种基于神经网络架构搜索的联邦学习系统,包括中央服务器以及若干个参与方设备,且所述参与方设备保存有本地数据集。
显然,本领域的技术人员可以对本发明进行各种改动和变型而不脱离本发明的精神和范围。这样,倘若本发明的这些修改和变型属于本发明权利要求及其等同技术的范围之内,则本发明也意图包含这些改动和变型在内。
机译: 神经结构搜索系统,用于生成神经网络架构
机译: 神经网络架构搜索系统和方法,以及计算机可读记录介质
机译: 神经网络架构搜索方法,装置和系统