技术领域
本发明涉及分布式机器学习技术领域,具体而言涉及一种基于版本控制的分布式机器学习模型更新方法。
背景技术
机器学习技术在多种任务场景中取得了广泛的成功,但随着数据的膨胀和任务复杂性的增加,将海量数据并行分配到多个计算节点的分布式机器学习成为了高效且可行的解决方案。近年来,随着硬件设备和物联网技术的不断发展,越来越多的边缘设备可以参与分布式机器学习的计算,虽然更多设备的参与提高了系统整体的吞吐量和算力,但同时也增加了系统异构性从而对不同设备间的调度提出了挑战。
分布式机器学习具有两个性能指标,一方面,机器学习场景希望维护一定的模型更新一致性既统计效率(STE)从而降低收敛所需的迭代轮数;另一方面,分布式系统希望维护一定的系统吞吐量既硬件效率(HWE)从而降低每轮迭代计算所消耗的时间。可以看出系统的整体性能指标---收敛所消耗的时间收到这两个指标影响。
在分布式的场景下,传统的调度方法如完全同步(BSP)通过严格同步限制维护了良好的统计效率却限制了硬件效率;而完全异步(ASP)通过放松限制维护了最好的硬件效率却限制了统计效率。不同于传统计算,机器学习的支撑算法如随机梯度下降法(SGD)等往往具有一定的鲁棒性,不要求所采用的机器学习模型具有严格一致性。基于此,随后的改进调度算法如有限异步(SSP)和软同步(SP)均在寻找硬件效率和统计效率的权衡。然而,有限异步所依赖的设备性能近似假设以及软同步的静态同步参数设置使得它们难以应对复杂的甚至是动态的高系统异构性场景。
发明内容
本发明针对现有技术中的不足,提供一种基于版本控制的分布式机器学习模型更新方法,采用版本控制动态评估当前分布式系统的硬件效率和统计效率,并且采用在线强化学习方法动态追踪两者的最佳权衡以实现最优整体性能。
为实现上述目的,本发明采用以下技术方案:
一种基于版本控制的分布式机器学习模型更新方法,所述更新方法包括以下步骤:
S10,本地计算:工作节点基于本地参数以及其版本进行梯度计算和版本传递,随后发送包含版本信息的push请求向服务器节点通信;
S20,梯度通信:参数服务器将根据push请求的版本信息判断所属工作节点的通信合法性,并进行相应的操作;
S30,全局更新:参数服务器收集梯度,并根据强化学习方法所得到的控制信息控制全局参数更新和通信;
S40,参数通信:参数服务器根据步骤S20和步骤S30的控制结果发送当前最新全局参数以及版本信息,接收到的工作节点进行本地参数以及版本的更新;
重复以上过程直到满足停止条件:分布式系统的收敛所消耗的时间小于预设时间阈值。
为优化上述技术方案,采取的具体措施还包括:
进一步地,步骤S10中,所述本地计算的过程包括如下步骤:
S11,工作节点将根据本地参数w和版本V(w)计算梯度以及梯度的版本:假设该工作节点m所使用的批大小为n,将参数的版本传递给工作节点V(m)←V(w),随后根据本地数据(x
S12,工作节点m将梯度的版本信息V(m)压入push请求中,请求与参数服务器通信梯度和参数。
进一步地,步骤S20中,所述梯度通信包括如下步骤:
S21,参数服务器根据收集到的push请求中的版本信息,结合工作节点的梯度版本与参数服务器当前的全局参数版本差距进行合法性检查:若差距大于第一差距阈值则判定为曾经离线工作节点,发送丢弃回应;若差距小于第二差距阈值则判定为较快工作节点,发送跳过回应;否则,判定为正常工作节点,发送更新回应;所述第一差距阈值大于第二差距阈值;
S22,参数服务器将与步骤S21中判断为正常的工作节点进行参数通信以收集工作节点梯度。
进一步地,步骤S30中,所述全局更新包括如下步骤:
S31,参数服务器根据步骤S20的判断结果进行控制信息更新,根据更新后的控制信息和当前的控制阈值控制参数的更新和通信;
S32,参数服务器采用强化学习的方法生成自适应的控制阈值并更新当前控制阈值τ。
进一步地,步骤S31中,所述参数服务器根据步骤S20的判断结果进行控制信息更新,根据更新后的控制信息和当前的控制阈值控制参数的更新和通信的过程包括如下步骤:
S311,参数服务器根据步骤S20的判断结果进行版本延迟信息
S312,参数服务器根据步骤S311中的版本延迟信息
进一步地,更新过程中所采用的数据结构和操作为
进一步地,步骤S32中,所述参数服务器采用强化学习的方法生成自适应的控制阈值并更新当前控制阈值τ的过程包括如下步骤:
S321,参数服务器定期根据当前的控制阈值τ、收集到的版本延迟信息
S322,参数服务器采用强化学习算法产生新的动作,根据新的动作转移至新的状态,并根据该状态更新当前控制阈值τ用作后续的控制。
进一步地,步骤S321中,所述奖赏函数采用依赖于当前训练阶段的完全在线函数和采用离线数据驱动通过训练神经网络所得到拟合函数中的任意一种。
进一步地,步骤S40中,所述参数通信的过程包括如下步骤:
对工作节点接收到的控制信息进行判断:
若工作节点接收到参数服务器发送的丢弃回应则丢弃本地梯度并与参数服务器进行参数通信:
本发明的有益效果是:
(1)本发明的一种基于版本控制的机器学习模型更新方法,相较于其他方法首次提出了追求最佳的硬件效率和统计效率的权衡概念,并且通过实时地追踪结合强化学习以实现这一权衡。
(2)从性能上来说,本方法通过强化学习解决了静态同步参数这一瓶颈从而动态自适应地实现最佳性能。
(3)同时在鲁棒性方面,得益于本方法采用的离线节点检测和较快节点跳过机制,本方法无需依靠任何性能和问题假设可以适配绝大多数异构甚至是动态异构的场景并且发挥最佳的系统性能。
附图说明
图1是本发明的基于版本控制的分布式机器学习模型更新方法的整体流程图。
图2是本发明的基于版本控制的模型更新流程图。
图3是本发明的基于版本控制的系统算法描述图。
图4是本发明的基于强化学习的自适应阈值调整算法描述图。
具体实施方式
现在结合附图对本发明作进一步详细的说明。
需要注意的是,发明中所引用的如“上”、“下”、“左”、“右”、“前”、“后”等的用语,亦仅为便于叙述的明了,而非用以限定本发明可实施的范围,其相对关系的改变或调整,在无实质变更技术内容下,当亦视为本发明可实施的范畴。
结合图1,本发明提及一种基于版本控制的分布式机器学习模型更新方法,所述更新方法包括以下步骤:
S10,本地计算:工作节点基于本地参数以及其版本进行梯度计算和版本传递,随后发送包含版本信息的push请求向服务器节点通信。
S20,梯度通信:参数服务器将根据push请求的版本信息判断所属工作节点的通信合法性,并进行相应的操作。
S30,全局更新:参数服务器收集梯度,并根据强化学习方法所得到的控制信息控制全局参数更新和通信。
S40,参数通信:参数服务器根据步骤S20和步骤S30的控制结果发送当前最新全局参数以及版本信息,接收到的工作节点进行本地参数以及版本的更新。
重复以上过程直到满足停止条件:分布式系统的收敛所消耗的时间小于预设时间阈值。
本发明是一种基于版本控制的机器学习模型更新方法,如图1所示,主要包括如下步骤:工作节点首先根据本地参数和数据进行S10本地计算得到梯度和版本信息,随后发送push请求开始通信;参数服务器在接收到push请求后开始S20梯度通信通过检测梯度的版本合法性来决定是否收集该梯度;随后开始S30全局更新,通过控制信息控制后续的全局参数更新和通信并且使用强化学习自适应的调整控制阈值;最后根据S30结果进行S40参数通信;重复上述流程直到满足问题收敛条件。
其中具体的版本控制流程如图2所示,在S12工作节点发送完push请求后参数服务器将进入S21开始梯度的合法性检查。若该工作节点的梯度版本与参数服务器当前的全局参数版本差距过大则代表该节点为离线节点。由于离线节点版本过旧,本方法将发送丢弃回复,更新控制信息后发送最新全局参数和版本使离线节点开始新一轮计算。若该工作的梯度版本与参数服务器当前的全局参数版本差距很小则表示该工作节点为较快的工作节点,可以根据当前带宽情况可选择地接收较快工作节点地梯度,同时该较快工作节点无需更新本地参数直接开始新一轮计算。若该工作节点梯度版本合法,则参数服务器发送更新回复,收集该工作节点梯度并正常进入后续流程。然后参数服务器进入S30全局更新阶段,根据控制信息和控制阈值判断是否可以进行全局更新,若该工作节点参与更新后导致参数版本延迟超过延迟的阈值则不进行参数更新并将该工作节点放入等待队列中。若满足阈值则首先进入S32通过强化学习方法更新阈值,然后进入正常更新阶段S42。
其中图3为本发明所实现的一个基于版本控制的系统的算法描述,其中工作节点端在进行完S10本地计算后进入等待接收参数服务器回应阶段,根据接收到的不同回应进行不同的动作。接收到跳过回应则根据当前带宽情况可选择的进行本地梯度的发送,若收到跳过回应且选择发送则需要将梯度置为0,若选择不发送则采用增量式梯度更新
而表1为本发明所实现的一个基于版本控制的系统的关键数据结构和步骤表。
表1
本发明将参数和版本放入同一个数据结构(w,V(w))中用以方便后续计算,其中参数服务器负责全局参数和版本更新:
结合前人的收敛性证明,通过定义P工作节点的个数,η为学习率,L为李普希茨条件,γ为梯度的方差上界以及N
通过设置合适的学习率:
就可以得到本发明所提出的算法1收敛性证明:
图4为本发明所采用的自适应控制阈值调整算法,本方法将控制阈值τ与状态空间进行绑定,并定义增加、减少和不改变阈值τ为动作空间,将追踪最佳的控制阈值转化为强化学习追求最大化统计效率加硬件效率这一问题。
以上仅是本发明的优选实施方式,本发明的保护范围并不仅局限于上述实施例,凡属于本发明思路下的技术方案均属于本发明的保护范围。应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明原理前提下的若干改进和润饰,应视为本发明的保护范围。
机译: 机器学习模型更新系统,边缘设备,机器学习模型更新方法和程序
机译: 基于分布式系统和机器学习模型的HEVC设备方法和系统利用块链网络
机译: 在基于云的计算环境中使用分布式LED技术为智能合约实施机器学习模型的系统,方法和设备