首页> 中国专利> 通信高效联合学习

通信高效联合学习

摘要

本公开涉及通信高效联合学习。本公开提供一种用于在诸如例如联合学习框架的机器学习框架内传输模型更新的有效率通信技术,该联合学习框架中在大量客户端上分布的训练数据上训练高品质集中化模型,每个客户端具有不可靠的网络连接和低计算能力。在一个示例联合学习设置下,在多轮中的每一轮中,每个客户端独立地基于其本地数据来更新模型并且将更新的模型传送回到服务器,其中所有客户端侧更新被用来更新全局模型。本公开提供降低通信成本的系统和方法。特别地,本公开提供至少:结构化更新方案,其中模型更新被约束为小并且为概略的更新方案,其中,模型更新在发送到服务器之前被压缩。

著录项

说明书

分案说明

本申请属于申请日为2017年9月25日的中国发明专利申请201710874932.2的分案申请。

技术领域

本公开大体涉及机器学习。更具体地,本公开涉及通信高效联合学习。

背景技术

随着数据集越来越大,模型越来越复杂,机器学习越来越多地需要在多台机器上分布模型参数的优化。现有机器学习算法通常仅适用于受控环境(诸如数据中心),其中数据在机器间适当地分布,并且高吞吐量光纤网络可用。

发明内容

本公开的实施例的方面和优点将部分地在下文的描述中阐述,或者可以从说明书中学习,或者可以通过实践实施例而学习。

本公开的一个示例方面针对于一种用于通信高效机器学习的计算机实现的方法。该方法包括:由客户端计算设备获得机器学习模型的参数集的全局值。该方法包括:由客户端计算设备至少部分地基于本地数据集来训练机器学习模型以获得更新矩阵,更新矩阵描述机器学习模型的参数集的经更新值。更新矩阵被约束为具有预定义的结构。本地数据集由客户端计算设备本地存储。该方法包括:由客户端计算设备将描述更新矩阵的信息传送到服务器计算设备。

本公开的另一示例方面涉及一种客户端计算设备。该客户端设备包括至少一个处理器;以及,存储指令的至少一个非暂时计算机可读介质,该指令在由至少一个处理器执行时使得客户端计算设备执行操作。该操作包括获得机器学习模型的参数集的全局值。该操作包括:至少部分地基于本地数据集来训练机器学习模型以获得更新矩阵,该更新矩阵描述机器学习模型的参数集的经更新值。本地数据集由客户端计算设备本地存储。该操作包括编码更新矩阵以获得经编码更新。该操作包括将经编码更新传送到服务器计算设备。

本公开的另一示例方面涉及存储指令的至少一个非暂时性计算机可读介质,该指令在由客户端计算设备执行时使得客户端计算设备实现操作。该操作包括获得机器学习模型的参数集的全局值。该操作包括至少部分地基于本地数据集来训练机器学习模型以获得更新矩阵,该更新矩阵描述机器学习模型的参数集的经更新值。本地数据集由客户端计算设备本地存储。更新矩阵被约束为低秩矩阵和稀疏矩阵中的至少一种。该操作包括编码更新矩阵以获得经编码更新。操作包括将经编码更新传送到服务器计算设备。

本公开的其它方面涉及各种系统、装置、非暂时性计算机可读介质、用户界面、和电子设备。

在本说明书中描述的主题的特定实施例可以被实现以便达到以下技术效果中的一个或多个。例如,在各个实施方式中,可以通过在多个客户端计算设备(例如,用户移动设备)上本地训练机器学习模型来实现有效率处理,从而充分利用多个设备的计算能力。此外,由客户端计算设备基于本地数据集来训练机器学习模型,可以提高训练过程的安全性。这是因为,例如,模型更新的信息相较于数据本身较不敏感。隐私敏感的用户数据保留在用户的计算设备上,并且不会上传到服务器。相反,只有较不敏感的模型更新被传输。

此外,由于互联网连接的不对称性质,从客户端到服务器的上行链路可能比从服务器到客户端的下行链路更慢,使得每个客户端设备传输完整的、无限制模型可能是没有效率的。然而,通过根据各个实施例(例如通过将更新矩阵约束为低秩矩阵和稀疏矩阵中的至少一个)来限制更新矩阵具有预定义的结构,达到了机器学习框架内的更有效率通信,例如从网络的角度、从客户端设备工作负荷/资源的角度、和/或从试图执行尽可能多轮的学习/尽可能快更新的角度。

参考以下描述和所附权利要求书,将更好地理解本公开的各种实施例的这些和其它特征、方面、和优点。并入并构成本说明书的一部分的附图图示了本公开的示例实施例,并且与描述一起用作解释相关原理。

附图说明

在参考附图的说明书中阐述了针对本领域普通技术人员的实施例的详细说明,在附图中:

图1描绘了根据本公开的示例实施例的示例计算系统的框图。

图2描绘了根据本公开的示例实施例的示例计算系统的框图。

图3描绘了根据本公开的示例实施例的,用于执行通信高效机器学习的示例方法的流程图。

图4描绘了根据本公开的示例实施例的,用于执行通信高效机器学习的示例方法的流程图。

图5描绘了根据本公开的示例实施例的,用于执行通信高效机器学习的示例方法的流程图。

图6描绘了根据本公开的示例实施例的示例实验结果的曲线图。

具体实施方式

总体而言,本公开提供了在机器学习框架内执行有效率通信的系统和方法。例如,本文提供的有效率通信技术可以在联合学习框架的场境(context)中实现。然而,本文提供的有效率通信技术也可以在其它机器学习框架中实现。例如,本文提供的有效率通信技术可以在任何实例下实现,其中第一计算设备负责计算对模型的更新和/或将模型的更新传送到第二计算设备。

更具体地,联合学习是机器学习框架,其使得能够基于在大量客户端计算设备上分布的训练数据来训练高品质集中化模型。客户端通常具有低计算能力和/或与网络的慢/不稳定连接。在某些实例下,联合学习也可以被称为“联合优化”。

对联合学习有推动作用的示例产生于当训练数据本地保持在用户的移动计算设备上并且将这样的移动计算设备用作为其对本地数据执行计算的节点以更新全局模型的时候。因此,联合学习框架与常规分布式机器学习有所不同,因为客户端数目众多,数据高度不平衡并且不是独立和完全相同分布(“IID”),以及网络连接不可靠。

与在集中化服务器上执行学习相比,联合学习提供了若干明显的优势。例如,模型更新的信息相比数据本身较不敏感。因此,隐私敏感的用户数据保留在用户的计算设备上,并且不会上传到服务器。相反,只有较不敏感的模型更新被传输。作为另一个优点,联合学习充分利用大量计算设备(例如,用户移动设备)的计算能力。此外,对上述框架的轻微修改可以导致在其相应的设备处为每个用户创建和使用个性化模型。

在某些实施方式中,实现联合学习的系统可以在多轮模型优化中的每一轮中执行以下动作:选择客户端的子集;子集中的每个客户端基于其本地数据来更新模型;更新的模型或模型更新由每个客户端发送到服务器;服务器聚合更新(例如,通过对更新进行平均)并改进全局模型;并且服务器将全局模型重新分发给所有客户端。执行多轮上述动作基于存储在客户端设备处的训练数据而迭代地改进全局模型。

在上述框架的基本实施方式中,每个客户端设备在每一轮中将完整的模型发送回服务器。然而,每个客户端设备的完整模型的传输显然是“昂贵的”(例如,从网络的角度来看,从客户端设备的工作负荷/资源的角度来看,和/或从尝试执行尽可能多轮的学习/尽可能快的更新的角度来看)。特别地,由于互联网连接的不对称性质,从客户端到服务器的上行链路通常比从服务器到客户端的下行链路慢得多。

鉴于这样的显著的上传成本,本公开提供了在机器学习框架内(例如,在上文讨论的联合学习框架内)执行有效率通信的系统和方法。特别地,本公开提供了有效率通信技术,其降低了从客户端设备向服务器传送更新的模型或模型更新的成本。

更具体地,本公开提供通信高效方案,其包括:结构化更新方案,该方案中模型更新被约束为小而概略(sketched)的更新方案,其中模型更新在发送到服务器之前被压缩。这些方案可以组合起来,例如通过首先学习结构化更新并且然后对其进行概略。

在讨论由本公开提供的通信高效方案之前,将讨论联合学习环境中的模型更新的基本实施方式,并将介绍符号。

在一个示例联合学习框架中,目标是学习具有在实矩阵

在轮t≥0时,服务器将当前模型W

然后,每个客户端将更新发送回服务器,其中通过聚合所有客户端侧更新来计算全局更新。例如,一个聚合方式(aggregation scheme)可以是如下:

此外,在一些实施方式中,可以使用加权和基于期望的效能来替换平均值。服务器也可以选择学习速率η

在执行深度学习的实施方式中,可以使用单独的矩阵W来表示每个层的参数。因此,本文中每次对矩阵的提及(例如,模型参数矩阵

在许多实例下,W和

根据本公开的方面,提供了降低向服务器发送

由本公开提供的第一类型的通信高效更新是其中

在低秩结构化更新技术中,每个更新

作为一个示例,

在一些实施方式中,固定

在一些实施方式中,低秩结构化更新技术可以提供k/d

在随机掩码结构化更新技术中,每个更新

类似于低秩方案,可以基于随机种子来完全指定或以其它方式生成稀疏模式。因此,客户端设备只需要发送

由本公开提供的第二类型的通信高效更新是其中客户端在发送到服务器之前以压缩形式对更新

本公开设想到许多不同类型的编码或压缩。例如,压缩可以是无损压缩或有损压缩。下面更详细地描述两个示例编码技术:子采样技术和量化技术。

对更新进行编码的一种方式是仅对由更新描述的参数的随机子集进行采样。换言之,代替传送完整更新

然后,服务器聚合采样的更新(例如,通过取平均),产生全局更新

在一些实施方式中,可以通过使用例如如上所描述的随机掩码来选择被采样的参数部分。特别地,可以使用种子来形成参数掩码,其识别哪些参数被采样。可以为每轮和/或为每轮中的每个客户端生成不同的掩码。掩码本身可以存储为同步的种子。通过零均值的随机矩阵扰乱SGD的预期迭代——这是子采样策略的某些实施方式的做法——不会影响这种类型的收敛。

对更新进行编码的另一种方式是通过量化权重。例如,可以概率性地量化权重。

首先,将描述用于将每个标量值量化为1位的算法。考虑更新

容易示出

此外,对于每个标量,上述可以被推广到超过1位。例如,对于b位量化,可以将[h

在一些实施方式中,区间不一定是均匀间隔的,而是作为替代可以被动态间隔。对于使用动态间隔区间的实施方式,客户端可以向服务器提供标识每个区间的max(最大)/min(最小)值的表。客户端通常会将min和max传输到服务器,而不管是否使用动态间隔。在一些实施方式中,可以在量化更新设置中类似地使用增量、随机、和/或分布式优化算法。

根据本公开的另一方面,可以通过使用随机旋转来改进上述量化技术。更具体地,当标量跨不同维度而近似相等时,上述1位和多位量化方案效果最好。例如,当max=100和min=-100并且大多数值为0时,1位量化将导致大的量化误差。

因此,根据本公开的一个方面,量化之前执行h的随机旋转可以解决这个问题。例如,更新可以乘以可逆矩阵R。例如,旋转矩阵可以是正交矩阵。在量化之前随机旋转h可以跨区间更均匀地分布标量值。在解码阶段,服务器可以在聚合所有更新之前执行逆旋转。

在一些实施方式中,h的维数可以高达d=1M,并且在旋转矩阵中生成

在一些实施方式,相同的旋转矩阵用于所有客户端设备和/或用于所有轮次。在一些实施方式,服务器然后可以在对经聚合的更新进行逆旋转之前对更新执行一些聚合。在一些实施方式,服务器在聚合之前解码每个更新。

图1描绘了用于使用在本地存储在多个客户端设备102上的相应训练数据108来训练一个或多个全局机器学习模型106的示例系统100。系统100可以包括服务器设备104。服务器104可被配置为访问机器学习模型106,并且将模型106提供给多个客户端设备102。例如,模型106可以是线性回归模型、逻辑回归模型、支持向量机模型、神经网络(例如卷积神经网络,递归神经网络等)、或其它合适的模型。在一些实施方式中,服务器104可以被配置为通过一个或多个网络与客户端设备102进行通信。

客户端设备102均可以被配置为至少部分地基于训练数据108来确定与模型106相关联的一个或多个本地更新。例如,训练数据108可以是分别本地地存储在客户端设备106上的数据。训练数据108可以包括音频文件、图像文件、视频文件、键入历史、位置历史、和/或各种其它合适的数据。在一些实施方式中,训练数据可以是通过与客户端设备102的用户交互导出的任何数据。

除了上面的描述之外,用户可以被提供有控件来允许用户对本文所描述的系统、程序或特征是否以及何时可以实现收集、存储、和/或使用用户信息(例如,训练数据108)以及用户是否从服务器发送内容或通信二者做出选择。此外,某些数据可以在存储或使用之前以一种或多种方式处理,使得移除个人可识别信息。例如,可以对用户的可识别进行处理,使得无法确定用户个人可识别信息,或者可以在获得位置信息的情况下将用户的地理位置泛化(诸如到城市,邮政编码、或州层级),使得无法确定用户的具体位置。因此,用户可以控制关于用户收集什么信息、如何使用该信息以及向用户提供什么信息。

虽然训练数据108在图1中被图示为单个数据库,但是训练数据108由分别存储在每个设备102处的数据组成。因此,在一些实施方式中,训练数据108是高度不平衡的,并且不是独立和完全相同分布的。

客户端设备102可以被配置为向服务器104提供本地更新。如上所述,训练数据108可以是隐私敏感的。以这种方式,可以执行本地更新并将其提供给服务器104,而不会损害到训练数据108的隐私性。例如,在这样的实施方式中,训练数据108并不提供给服务器104。本地更新不包括训练数据108。在将本地更新的模型提供给服务器104的一些实施方式中,一些隐私敏感数据可能能够从模型参数导出或推断。在这样的实施方式中,可以将加密技术、随机噪声技术、和/或其它安全技术中的一个或多个技术添加到训练过程中以模糊任何可推断的信息。

如上所述,服务器104可以从客户端设备102接收每个本地更新,并且可以聚合该本地更新以确定对模型106的全局更新。在一些实施方式,服务器104可以确定本地更新的平均值(例如,加权平均值)并且至少部分地基于该平均值来确定全局更新。

在一些实施方式,扩缩或其它技术可以应用于本地更新以确定全局更新。例如,可以为每个客户端设备102应用本地步长,可以与客户端设备102的各个数据分区大小成比例地来执行聚合,和/或一个或多个扩缩因子可以应用于本地更新和/或经聚合更新。应当理解,可以应用各种其它技术而不偏离本公开的范围。

图2描绘了可用于实现本公开的方法和系统的示例计算系统200。系统200可以使用客户端-服务器架构来实现,该客户端-服务器架构包括通过网络242与一个或多个客户端设备230进行通信的服务器210。因此,图2提供了可以实现由图1的系统100所图示的方式的示例系统200。

系统200包括诸如web服务器的服务器210。服务器210可以使用任何合适的计算设备来实现。服务器210可以具有一个或多个处理器212和一个或多个存储器设备214。可以使用一个服务器设备或多个服务器设备来实现服务器210。在使用多个设备的实施方式中,这样的多个设备可以根据并行计算架构、串行计算架构、或其组合来操作。

服务器210还可以包括用于通过网络242与一个或多个客户端设备230通信的网络接口。网络接口可以包括用于与一个或多个网络对接的任何合适的组件,包括例如发射器、接收器、端口、控制器、天线、或其它合适的组件。

一个或多个处理器212可以包括任何合适的处理设备,诸如微处理器、微控制器、集成电路、逻辑设备、或其它合适的处理设备。一个或多个存储器设备214可以包括一个或多个计算机可读介质,包括但不限于非暂时性计算机可读介质、RAM、ROM、硬盘驱动器、闪存驱动器、或其它存储器设备。一个或多个存储器设备214可以存储可由一个或多个处理器212访问的信息,包括可由一个或多个处理器212执行的计算机可读指令216。

指令216可以是指令的任何集合,该指令在由一个或多个处理器212执行时使得一个或多个处理器212执行操作。例如,指令216可以由一个或多个处理器212执行以实现全局更新器220。全局更新器220可以被配置为接收一个或多个本地更新并且至少部分地基于该本地更新来确定全局模型。

指令216还可以包括使得服务器210实现解码器222的指令。解码器222可以解码已经由客户端设备230编码的更新(例如,根据上文讨论的编码技术之一,诸如子采样、量化、随机旋转等)。

如图2所示,一个或多个存储器设备214还可以存储数据218,其可由一个或多个处理器212检索、操纵、创建、或存储。数据218可以包括例如本地更新、全局参数、和其它数据。数据218可以存储在一个或多个数据库中。一个或多个数据库可以通过高带宽LAN或WAN连接到服务器210,或者还可以通过网络242连接到服务器210。一个或多个数据库可以被拆分,使得它们位于多个地点中。

服务器210可以通过网络242与一个或多个客户端设备230交换数据。任何数目的客户端设备230可以通过网络242连接到服务器210。客户端设备230中的每一个可以是任何合适类型的计算设备,诸如通用计算机、专用计算机、膝上型计算机、台式计算机、移动设备、导航系统、智能电话、平板计算机、可穿戴计算设备、游戏控制台、具有一个或多个处理器的显示器、或其它合适的计算设备。

与服务器210类似,客户端设备230可以包括一个或多个处理器232和存储器234。一个或多个处理器232可以包括例如一个或多个中央处理单元(CPU)、专用于有效率地渲染图像或执行其它专门计算的图形处理单元(GPU)、和/或其它处理设备。存储器234可以包括一个或多个计算机可读介质并且可以存储可由一个或多个处理器232访问的信息,包括可由一个或多个处理器232执行的指令236和数据238。

指令236可以包括用于实现本地更新器的指令,本地更新器根据本公开的示例方面配置成确定一个或多个本地更新。例如,本地更新器可以执行一种或多种训练技术,诸如后向传播误差以基于本地存储的训练数据来重新训练或以其它方式更新模型。本地更新器可以被配置为执行结构化更新、概略更新、或其它技术。本地更新器可以被包括在应用中,或者可以被包括在设备230的操作系统中。

指令236还可以包括用于实现编码器的指令。例如,编码器可以执行上述编码技术中的一个或多个(例如,子采样,量化,随机旋转等)。

数据238可以包括用于解决一个或多个优化问题的一个或多个训练数据示例。每个客户端设备230的训练数据示例可以在客户端设备中不均匀地分布,使得客户端设备230不包括训练数据示例的总体分布的代表性样本。

数据238还可以包括要传送到服务器210的更新的参数。

图2的客户端设备230可以包括用于提供和接收来自用户的信息的各个输入/输出设备,诸如触摸屏、触摸板、数据录入键、扬声器、和/或适合于语音识别的麦克风。

客户端设备230还可以包括用于通过网络242与一个或多个远程计算设备(例如,服务器210)进行通信的网络接口。网络接口可以包括用于与一个或多个网络对接的任何合适的组件,包括例如发射器、接收器、端口、控制器、天线、或其它合适的组件。

网络242可以是任何类型的通信网络,诸如局域网(例如内联网)、广域网(例如互联网)、蜂窝网络、或其某些组合。网络242还可以包括在客户端设备230与服务器210之间的直接连接。通常,可以使用任何类型的有线和/或无线连接、使用各种通信协议(例如TCP/IP、HTTP、SMTP、FTP)、编码或格式(例如HTML、XML)、和/或保护方式(如VPN、安全HTTP、SSL)经由网络接口来实施服务器210与客户端设备230之间的通信。

图3描绘了根据本公开的示例实施例的,确定全局模型的示例方法(300)的流程图。方法(300)可以由一个或多个计算设备来实现,该计算设备诸如图1和/或2所描绘的计算设备中的一个或多个。此外,图3描绘了为了说明和讨论的目的而以特定顺序执行的步骤。本领域普通技术人员使用本文提供的公开将理解,本文讨论的任何方法的步骤可以以各种方式进行调整、重新排列、扩展、省略、或修改,而不脱离本公开的范围。

在(302),方法(300)可以包括由客户端设备基于一个或多个本地数据示例来确定本地模型。特别地,可以使用一个或多个数据示例来针对损失函数确定本地模型。例如,数据示例可以通过用户与客户端设备交互来生成。在一些实施方式中,模型可能已经在(302)的本地训练之前被预先训练。在一些实施方式中,可以在(302)处使用结构化更新、概略更新、或其它技术来使所学习的本地模型或本地更新变得通信高效。

在(304),方法(300)可以包括由客户端设备将本地模型提供给服务器,并且在(306),方法(300)可以包括由服务器接收本地模型。在一些实施方式中,可以在将本地模型或本地更新发送到服务器之前对该本地模型或本地更新进行编码或压缩。

在(308),方法(300)可以包括由服务器至少部分地基于所接收的本地模型来确定全局模型。例如,可以至少部分地基于由多个客户端设备提供的多个本地模型来确定全局模型,所述多个客户端设备均具有多个不均匀分布的数据示例。特别地,数据示例可以分布在客户端设备中,使得客户端设备不包括数据的总体分布的代表性样本。此外,客户端设备的数目可以超过任何一个客户端设备上的数据示例的数目。

在一些实施方式中,作为聚合过程的一部分,服务器可以解码每个接收到的本地模型或本地更新。

在(310),方法(300)可以包括向每个客户端设备提供全局模型,并且在(312),方法(300)可以包括接收全局模型。

在(314),方法(300)可以包括由客户端设备确定本地更新。在一个特定实施方式中,可以基于本地存储的训练数据重新训练或以其它方式更新全局模型来确定本地更新。在一些实施方式,可以在(314)处使用结构化更新、概略更新、或其它技术来使所学习的本地模型或本地更新变得通信高效。

在一些实施方式中,可以至少部分地基于使用一个或多个随机更新或迭代来确定本地更新。例如,客户端设备可以随机对存储在客户端设备上的数据示例的分区进行采样来确定本地更新。特别地,可以使用随机模型下降技术来确定本地更新来确定调整损失函数的一个或多个参数的方向。

在一些实施方式中,可以至少部分地基于存储在客户端设备上的数据示例的数目来确定与本地更新确定相关联的步长。在进一步的实施方式中,随机模型可以使用对角矩阵或其它扩缩技术进行扩缩。在另外的实施方式中,可以使用强制每个客户端设备在相同方向上更新损失函数的参数的线性项来确定本地更新。

在(316)中,方法(300)可以包括由客户端设备向服务器提供本地更新。在一些实施方式中,可以在将本地模型或更新发送到服务器之前对本地模型或本地更新进行编码。

在(318),方法(300)可以包括由服务器接收本地更新。具体地,服务器可以从多个客户端设备接收多个本地更新。

在(320),方法(300)可以包括再次确定全局模型。特别地,可以至少部分地基于所接收的本地更新来确定全局模型。例如,可以聚合所接收到的本地更新以确定全局模型。聚合可以是加法聚合和/或平均聚合。在特定实施方式中,本地更新的聚合可以与客户端设备上的数据示例的分区大小成比例。在另外的实施例中,本地更新的聚合可以以每坐标方式来进行扩缩。

可以执行任何次数的对本地更新和全局更新迭代。即,可以迭代地执行方法(300)以随时间推移基于本地存储的训练数据来更新全局模型。

图4描绘了根据本公开的示例实施例的,用于执行通信高效机器学习的示例方法400的流程图。例如,方法400可以由客户端计算设备执行。

在402,客户端计算设备获得机器学习模型的参数集的全局值。

在404,客户计算设备至少部分地基于本地数据集来训练机器学习模型以获得更新矩阵,该更新矩阵描述机器学习模型的参数集的经更新值。更新矩阵被限制为具有预定义的结构。本地数据集由客户端计算设备本地存储。在一些实施方式中,更新矩阵描述了参数集的经更新值和/或经更新值与全局值之间差异。

在一些实施方式,更新矩阵被限制为低秩矩阵。

在一些实施方式中,在404处训练机器学习模型可以包括由客户端计算设备将更新矩阵定义为第一矩阵和第二矩阵的乘积。第一矩阵可以包括固定值,并且第二矩阵可以包括可优化的变量。客户端计算设备可以至少部分地基于本地数据集来训练机器学习模型以获得第二矩阵。

在一些这样的实施方式中,方法400还可以包括:在404处训练模型之前:至少部分地基于种子和伪随机数生成器来生成第一矩阵。客户端计算设备和服务器计算设备可以均知晓种子,使得第一矩阵可由服务器计算设备重现。

在一些实施方式中,更新矩阵被限制为稀疏矩阵。

在一些实施方式中,在404处训练机器学习模型可以包括至少部分地基于本地数据集来训练机器学习模型,使得仅针对参数集的预选部分确定经更新值。在这样的实施方式中,更新矩阵可以仅描述参数集的预选部分的经更新值。

在一些这样的实施方式中,方法400还可以包括:在404处训练模型之前:生成参数掩码,该参数掩码指定参数集中的哪些参数被包括在该参数集的预选部分中。例如,生成参数掩码可以包括由客户端计算设备至少部分地基于种子和伪随机数生成器来生成参数掩码,其中客户端计算设备和服务器计算设备均知晓种子,使得参数掩码可由服务器计算设备重现。

在406,客户端计算设备将描述更新矩阵的信息传送到服务器计算设备。

作为一个示例,在404处训练模型包括优化第二矩阵的实施方式中,在406处传送描述更新矩阵的信息可以包括将描述第二矩阵的信息传送到服务器计算设备。

图5描绘了根据本公开的示例实施例的,用于执行通信高效机器学习的示例方法500的流程图。例如,方法500可以由客户端计算设备执行。

在502,客户端计算设备获得机器学习模型的参数集的全局值。

在504,客户计算设备至少部分地基于本地数据集来训练机器学习模型以获得更新矩阵,该更新矩阵描述机器学习模型的参数集的经更新值。该本地数据集由客户端计算设备本地存储。在一些实施方式中,该更新矩阵描述了参数集的经更新值和/或经更新值与全局值之间的差异。

在506,客户端计算设备对更新矩阵进行编码以获得经编码更新。

在一些实施方式中,在506处编码更新矩阵可以包括对更新矩阵进行子采样以获得经编码更新。在一些这样的实施方式中,对更新矩阵进行子采样可以包括:生成参数掩码,该参数掩码指定要采样的参数集的一部分;并根据该参数掩码对更新矩阵进行子采样。

在一些这样的实施方式中,生成参数掩码可以包括至少部分地基于种子和伪随机数生成器来生成参数掩码,其中客户端计算设备和服务器计算设备均知晓种子,使得参数掩码可由服务器计算设备重现。

在一些实施方式中,在506处编码更新矩阵可以包括对包括在更新矩阵中的一个或多个值进行概率性量化。在一些实施方式中,在506处编码更新矩阵可以包括对在更新矩阵中包括的一个或多个值执行概率性二进制量化,以将一个或多个值中的每一个改变为在更新矩阵中所包括的最大值或在更新矩阵中所包括的最小值。在一些实施方式中,在506处编码更新矩阵可以包括:在在更新矩阵中所包括的最大值与在更新矩阵中所包括的最小值之间定义多个区间;并且将在更新矩阵中所包括的一个或多个值概率性地改变为局部区间最大值或局部区间最小值。

在一些实施方式中,在506处编码更新矩阵可以包括将更新矩阵的向量乘以旋转矩阵以获得旋转更新。在一些这样的实施方式中,编码更新矩阵还可以包括概率性地量化旋转更新中所包括的一个或多个值。在一些实施方式中,旋转矩阵可以是不需要客户端计算设备完全生成旋转矩阵的结构化旋转矩阵。

在508,客户端计算设备将经编码更新传送到服务器计算设备。该服务器计算设备可以对经编码更新进行解码。

使用联合学习进行示例实验来训练用于CIFAR-10图像分类任务的深层神经网络(参见Krizhevsky.Learning multiple layers of features from tiny images(根据微小图像来学习多层的特征).Technical report,2009)。存在50000个训练示例,其分成100个客户端,每个客户端包含500个训练示例。模型架构取自TensorFlow教程(Tensorflowconvolutional neural networks tutorial(Tensorflow卷积神经网络教程).http://www.tensorflow.org/tutorials/deep_cnn,2016),其由两个卷积层组成,后面是两个全连接层,并且然后是线性变换层以为总共超过1e6个参数产生logit。虽然这种模型不是最先进的,但是不同于为了实现此任务的最佳可能准确性,其足够用于评估本文描述的压缩方法的目的。

采用联合取平均算法(McMahan等人。Federated Learning of deep networksusing model platforms(使用模型平台的深度网络联合学习).ArXiv:1602.05629,160),这大大减少了训练良好模型所需的通信轮次。然而,当应用于同步SGD时,这些技术预期将示出通信成本的类似降低。对于联合平均化,对于总共100个本地更新,每轮随机选择10个客户端,其中每一个使用50个图像的小批次在其本地数据集上以η学习速率来执行10期SGD。从该经更新模型中,计算

表1提供了示例CIFAR实验的低秩和采样参数。采样概率列给出了分别针对两个卷积层和两个全连接层上传的元素部分;这些参数由StructMask、SketchMask、和SketchRotMask使用。低秩列给出了这四个层的秩限制k。最后的softmax层小,所以对其的更新没有被压缩。

表1:

图6描绘了示例非量化结果(左列和中间列)以及包括二进制量化的结果(虚线SketchRotMaskBin和SketchMaskBin,右列)的曲线图。注意,右上曲线图的x轴是对数标度。利用少于100MB的通信来实现了超过70%的准确度。

定义了中和高低秩/采样参数设定,使得这两种方案的压缩率相同,如表1所给出。图6的左列和中心列显示了测试集准确度的非量化结果,既作为算法轮数的函数,并且也作为上传的总兆字节数的函数。对于所有实验,使用以0.15为中心的分辨率

对于中等子采样,在固定的带宽使用量之后,所有这三种方案提供了测试集准确度的大大提高;除了StructLowRank方案对于高子采样参数表现较差外,下排的曲线图作为更新轮数的函数示出准确度的极少损失。

图6中右边的两个曲线图给出了在有二进制量化和没有二进制量化的情况下SketchMask和SketchRotMask的结果;只考虑中等子采样方式,其是代表性的。观察到(如预期)在没有量化的情况下引入随机旋转基本上没有影响。然而,二进制量化大大降低了总通信成本,并且进一步引入随机旋转显着地加速了收敛,并且还允许收敛到更高准确度水平。能够以仅约100MB的通信来学习合理的模型(70%的准确度),比基线小两个数量级。

本文讨论的技术引用了服务器、数据库、软件应用、和其它基于计算机的系统,以及所采取的措施和向这样的系统发送的信息和从这样的系统发送的信息。基于计算机的系统的固有灵活性允许组件之间和组件当中的各种可能的任务和功能的配置、组合、和划分。例如,可以使用单个设备或组件或组合地工作的多个设备或组件来实现本文讨论的过程。数据库和应用程序可以在单个系统上实现或跨多个系统分布。分布式组件可以顺序或并行地操作。

虽然已经关于本主题的各种具体示例实施例详细描述了本主题,但是通过说明而不是限制本公开来提供每个示例。本领域技术人员在理解上述内容之后可以容易地针对这样的实施例做出改变、变型、和等同物。因此,本公开不排除对本领域的普通技术人员显而易见的对本主题的这样的修改、变型和/或添加。例如,作为一个实施例的一部分图示或描述的特征可以与另一个实施例一起使用以产生又一个实施例。因此,本公开旨在涵盖这样的改变、变型、和等同物。

去获取专利,查看全文>

相似文献

  • 专利
  • 中文文献
  • 外文文献
获取专利

客服邮箱:kefu@zhangqiaokeyan.com

京公网安备:11010802029741号 ICP备案号:京ICP备15016152号-6 六维联合信息科技 (北京) 有限公司©版权所有
  • 客服微信

  • 服务号