首页> 中国专利> 一种基于SeqGAN的深度强化学习数据增强防御方法和装置

一种基于SeqGAN的深度强化学习数据增强防御方法和装置

摘要

本发明公开了一种基于SeqGAN的深度强化学习数据增强防御方法和装置,首先,利用DQN算法对目标智能体进行预训练,再采样多个T时刻的序列状态动作对作为真实序列轨迹数据;其次将初始状态动作对输入到生成器中利用生成器来生成序列状态策略轨迹数据,使用判别器和基于策略梯度的蒙特卡洛法来评估生成的序列得到的奖励,用于引导生成器的训练,以用于生成接近真实的状态动作数据;最后通过比较由模型策略得到的序列累计奖励值和SeqGAN生成的状态策略得到的累计奖励值大小来对训练数据进行增强,以优化模型的策略,提高模型的鲁棒性。

著录项

  • 公开/公告号CN112884130A

    专利类型发明专利

  • 公开/公告日2021-06-01

    原文格式PDF

  • 申请/专利权人 浙江工业大学;

    申请/专利号CN202110281225.9

  • 发明设计人 陈晋音;章燕;王雪柯;胡书隆;

    申请日2021-03-16

  • 分类号G06N3/04(20060101);G06N3/08(20060101);

  • 代理机构33224 杭州天勤知识产权代理有限公司;

  • 代理人曹兆霞

  • 地址 310014 浙江省杭州市下城区潮王路18号

  • 入库时间 2023-06-19 11:11:32

说明书

技术领域

本发明属于面向深度强化学习的防御领域,具体涉及一种基于SeqGAN的深度强化学习数据增强防御方法和装置。

背景技术

深度强化学习是近年来人工智能备受关注的方向之一,随着强化学习的快速发展和应用,强化学习已经在机器人控制、游戏博弈、计算机视觉、无人驾驶等领域被广泛使用。为了保证深度强化学习在安全攸关领域的安全应用,关键在于分析、发现深度强化学习算法、模型中的漏洞以防止一些别有用心的人利用这些漏洞来进行非法牟利行为。不同于传统机器学习的单步预测任务,深度强化学习系统需要进行多步决策来完成某项任务,而且连续的决策之间还具有高度的相关性。

强化学习通常是一个连续决策的过程,其基本思想是通过最大化智能体从环境中获取的累计奖励,从而学习最优策略以达到学习目的。深度强化学习充分利用神经网络作为参数结构,结合了深度学习的感知能力和强化学习的决策能力来优化深度强化学习策略,最终实现从感知输入到决策输出的端到端深度强化学习框架,具有较强的学习能力且应用广泛。但是与机器学习模型相同的是,强化学习也容易受到对抗样本的干扰,根据对样本的攻击,主要可分为观测攻击、奖励攻击、动作攻击以及环境攻击。同时,深度强化学习的训练过程中也会受到中毒攻击,使得模型的训练代价加大,模型的性能降低。此外,深度强化学习训练的策略还存在安全性隐患,尤其是在安全关键型领域,如复杂作战环境、作战指挥环境、无人机装甲车、无人机监察机、智能机器人控制等,模型策略漏洞的存在会给强化学习系统带来严重的损害。尤其是在安全关键型领域中,这种安全性隐患带来了很大的危害,会使强化学习系统的决策发生错误,这对于强化学习的决策安全应用领域是重大挑战。

已有研究表明,通过策略中毒攻击可以通过改变训练集中的数据来使决策发生改变,从未使得智能体动作选取失误,智能体最终达不到学习目的。这种攻击对于无人驾驶等安全决策领域的应用是十分致命的。目前,根据现有的防御机制,常见的强化学习的防御方法可以分为对抗训练、鲁棒学习、对抗检测三大类。对抗训练是指将对抗样本加入到训练样本中对模型进行训练,其主要目的是提高策略对正常样本以外的泛化能力。但是对抗训练往往只能提高策略对参与训练的样本的拟合能力。鲁棒学习是训练模型在面对来自训练阶段或者测试阶段时的攻击方法时提高其自身鲁棒性的学习机制。对抗检测指模型对正常样本与对抗样本加以甄别,并在不修改原始模型参数的情况下处理对抗样本,来实现防御效果。

发明内容

鉴于深度强化学习在安全决策领域(例如自动驾驶场景)由于容易受到噪声扰动攻击而引起的安全威胁问题,本发明的目的是提供一种基于SeqGAN的深度强化学习数据增强防御方法和装置。通过数据增强的方式来优化深度强化学习模型,提升深度强化学习模型的鲁棒性,以防御攻击。

为实现上述发明目的,本发明提供以下技术方案:

第一方面,一种基于SeqGAN的深度强化学习数据增强防御方法,包括以下步骤:

搭建深度强化学习的智能体自动驾驶模拟环境,基于强化学习中的深度Q网络构建目标智能体,并对目标智能体进行强化学习以优化深度Q网络的参数;

利用参数优化的深度Q网络产生T个时刻的目标智能体驾驶的状态动作对序列作为专家数据,其中,状态动作对中的动作取值对应Q值最小的动作;

利用强化学习的方法来训练包含生成器和判别器的SeqGAN,以专家数据中状态动作对作为生成器的输入来生成状态动作对,同时采用基于策略梯度蒙特卡洛搜索来模拟采样,采样得到的状态动作对与生成器生成的状态动作对组成固定长度的状态动作对序列并输入至判别器,计算奖励值,依据该奖励值更新SeqGAN的网络参数;

将当前状态输入至参数优化的SeqGAN的生成器中以获得生成状态动作对序列,利用参数优化的深度Q网络来计算生成状态动作对序列的累计奖励值,将该累计奖励值与目标智能体的深度Q网络策略得到的累计奖励值进行比较,以累计奖励值更高的状态动作对作为增强数据存储用于对深度Q网络再优化;

从存储中选择增强数据对深度Q网络进行参数再优化,以实现深度强化学习数据增强防御。

第二方面,一种基于SeqGAN的深度强化学习数据增强防御装置,包括计算机存储器、计算机处理器以及存储在所述计算机存储器中并可在所述计算机处理器上执行的计算机程序,所述计算机处理器执行计算机程序时实现上述基于SeqGAN的深度强化学习数据增强防御方法。

与现有技术相比,本发明基于SeqGAN的深度强化学习数据增强防御方法和装置具有的有益效果至少包括:

1)通过训练SeqGAN来生成序列状态和策略动作以增强训练数据以优化目标智能体策略,提高DRL模型的鲁棒性;2)在SeqGAN训练过程中,生成器用来生成序列状态和策略动作轨迹数据,判别器的真实数据输入是采样自预训练的DQN模型的序列状态动作对数据,训练过程中通过更新生成器和判别器的参数来生成更为真实的序列状态策略数据;3)通过比较由模型策略得到的序列累计奖励值和SeqGAN生成的状态策略得到的累计奖励值大小来对训练数据进行增强,以优化模型的策略,提高模型的鲁棒性。

附图说明

为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图做简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动前提下,还可以根据这些附图获得其他附图。

图1是实施例提供的基于SeqGAN的深度强化学习数据增强防御方法的流程图;

图2是实施例提供的基于SeqGAN训练的序列数据生成过程示意图;

图3是实施例提供的化学习中DQN算法结构的示意图。

具体实施方式

为使本发明的目的、技术方案及优点更加清楚明白,以下结合附图及实施例对本发明进行进一步的详细说明。应当理解,此处所描述的具体实施方式仅仅用以解释本发明,并不限定本发明的保护范围。

在模拟小车自动驾驶的深度强化学习训练过程中,基于训练数据集中毒的攻击方法会使学习者学习到一个错误的策略,从而选择一个不好的动作,使得学习者学习错误。基于此种情况,实施例提供了一种基于SeqGAN的深度强化学习数据增强防御方法和装置,利用SeqGAN来生成序列状态动作数据来进行训练数据增强以达到优化目标智能体策略的目的。具体过程为:首先,利用DQN算法对目标智能体进行预训练,再采样多个T时刻的序列状态动作对作为真实序列轨迹数据;其次将初始状态动作对输入到生成器中利用生成器来生成序列状态策略轨迹数据,使用判别器和基于策略梯度的蒙特卡洛法来评估生成的序列得到的奖励,用于引导生成器的训练,以用于生成接近真实的状态动作数据;最后通过比较由模型策略得到的序列累计奖励值和SeqGAN生成的状态策略得到的累计奖励值大小来对训练数据进行增强,以优化模型的策略,提高模型的鲁棒性。

图1是实施例提供的基于SeqGAN的深度强化学习数据增强防御方法的流程图。如图1所示,实施例提供的深度强化学习数据增强防御方法包括以下步骤:

步骤1,搭建深度强化学习的智能体自动驾驶模拟环境,基于强化学习中的深度Q网络(DQN)构建目标智能体,并对目标智能体进行强化学习以优化深度Q网络的参数。

实施例中,智能体可以是自动驾驶环境中的小车,在利用强化学习中的深度Q网络训练小车,目的是使小车尽可能快递达到目的地。深度Q网络是将Q学习和卷积神经网络相结合,构建用于深度强化学习的训练模型。

其中,如图3所示,深度Q网络包括主Q网络和目标Q网络,其中,主Q网络作为目标智能体的决策网络,目标Q网络作为目标智能体的价值网络,主Q网络和目标Q网络均采用卷积神经网络。在对目标智能体进行强化学习时,主Q网络根据状态计算的决策Q值作为动作值,并根据动作值产生下一时刻的状态值和奖励值,状态值、动作值以及奖励值均被存储,目标Q网络从存储中取出下一时刻的环境计算目标Q值,并以主Q网络与目标Q网络输出的决策Q值和目标Q值以及累计奖励值来更新主Q网络,同时每隔一段时间将当前主Q网络复制给目标Q网络。

实施例中,DQN通过结合深度神经网络与强化学习的Q学习算法,不仅解决了状态空间过大难以维护的问题,而且由于神经网络强大的特征提取能力,其潜力也远大于人工的特征表示。强化学习中的Q学习通过贝尔曼方程,采用时序差分的方式进行迭代更新状态-动作价值函数Q:

Q

其中,

DQN还使用了目标网络机制,即在当前Q

其中,

训练过程中,DQN采用了经验回放机制,将状态转换过程(状态s

从Buff中采样N个训练数据集,通过最小化损失函数来更新当前Q

步骤2,利用参数优化的深度Q网络产生T个时刻的目标智能体驾驶的状态动作对序列作为专家数据,其中,状态动作对中的动作取值对应Q值最小的动作。

生成状态动作对序列的过程是一个采样过程,采用获得的T个时刻小车驾驶序列状态动作对{(s

步骤3,利用强化学习的方法来训练包含生成器和判别器的SeqGAN。

实施例中,在对包含多个生成器和判别器的SeqGAN进行参数优化时,将专家数据中状态动作对作为生成器的输入,生成器用于根据输入的状态动作对作生成状态动作对,同时采用基于策略梯度蒙特卡洛搜索来模拟采样来得到一些状态动作对,该些状态动作对与生成器生成的状态动作对形成固定长度的生成状态动作对序列,该生成状态动作对序列被输入至判别器,判别器用于根据输入的生成状态动作对序列计算奖励值;

从存储中在线采样固定长度的真实状态动作对序列输入至判别器,经计算获得真实状态动作对序列的奖励值,该奖励值与生成状态动作对序列的奖励值的交叉熵为损失函数,来更新判别器和生成器的参数。

具体地,如图2所示,SeqGAN的训练过程为:

生成器的目标是生成一个序列数据Y

当没有中间奖励值时,生成器G

其中,R

GAN的目的是使生成数据与目标数据分布(训练集数据所代表的的分布)相接近,其中判别器的目标是使两者的分布最小化,目标函数可表示为:

其中,P

训练过程中,使用判别器来评估生成的序列数据,来指导生成器的训练,通过策略梯度更新的方法对生成器的参数θ进行更新:

其中,α

步骤4,基于预训练的深度Q网络和参数优化的SeqGAN的预测结果筛选用于对深度Q网络的再优化的增强数据。

具体地,将当前状态输入至参数优化的SeqGAN的生成器G中以获得T个时刻的生成状态动作对序列,将其输入到DQN模型中,利用参数优化的深度Q网络来计算生成状态动作对序列的累计奖励值,将该累计奖励值与目标智能体的深度Q网络策略得到的累计奖励值进行比较,以累计奖励值更高的状态动作对作为增强数据并连同得到的奖励值作为模型的训练数据存入经验缓冲区。

其中,目标智能体的深度Q网络策略得到的累计奖励值的获取过程为:将在线采样的状态动作对序列作为依次深度Q网络的输入,利用深度Q网络计算获得深度Q网络策略得到的累计奖励值。

步骤5,从存储中选择增强数据对深度Q网络进行参数再优化,以实现深度强化学习数据增强防御。

具体地,目标智能体从Buff中采样N个训练数据集,通过最小化当前主Q网络的决策Q值和目标Q网络的目标Q值的均方差来更新当前主Q网络的网络参数,每隔一段时间将当前主Q网络的参数复制给目标Q网络,通过数据增强的方式对模型的策略进行优化,提高了模型的鲁棒性。

实施例还提供一种基于SeqGAN的深度强化学习数据增强防御装置,包括计算机存储器、计算机处理器以及存储在所述计算机存储器中并可在所述计算机处理器上执行的计算机程序,计算机处理器执行计算机程序时实现上述基于SeqGAN的深度强化学习数据增强防御方法。

实际应用中,计算机存储器可以为在近端的易失性存储器,如RAM,还可以是非易失性存储器,如ROM,FLASH,软盘,机械硬盘等,还可以是远端的存储云。计算机处理器可以为中央处理器(CPU)、微处理器(MPU)、数字信号处理器(DSP)、或现场可编程门阵列(FPGA),即可以通过这些处理器实现基于SeqGAN的深度强化学习数据增强防御方法步骤。

上述基于SeqGAN的深度强化学习数据增强防御方法和装置主要用于基于强化学习训练过程受到攻击从而使训练数据集发生改变的场景下。该方法基于序列对抗式生成网络(SeqGAN)来生成训练策略,从而利用生成的序列数据对模型进行优化,达到训练数据增强的目的,从而提高模型的鲁棒性。

以上所述的具体实施方式对本发明的技术方案和有益效果进行了详细说明,应理解的是以上所述仅为本发明的最优选实施例,并不用于限制本发明,凡在本发明的原则范围内所做的任何修改、补充和等同替换等,均应包含在本发明的保护范围之内。

去获取专利,查看全文>

相似文献

  • 专利
  • 中文文献
  • 外文文献
获取专利

客服邮箱:kefu@zhangqiaokeyan.com

京公网安备:11010802029741号 ICP备案号:京ICP备15016152号-6 六维联合信息科技 (北京) 有限公司©版权所有
  • 客服微信

  • 服务号