首页> 中国专利> 基于强化学习和知识蒸馏的多类别不平衡故障分类方法

基于强化学习和知识蒸馏的多类别不平衡故障分类方法

摘要

本发明公开了一种基于强化学习和知识蒸馏的多类别不平衡故障分类方法,该方法结合层次聚类、知识蒸馏和强化学习等算法,用来解决多类别不平衡故障分类问题。对于多类别故障分类问题,首先针对不平衡问题中同质类别样本之间存在相似性、异质类样本之间存在较大差异的特点使用层次聚类将多类别聚类为几个簇类,根据不同簇类分别建立学生网络进行细粒度化分类,再用知识蒸馏方法兼顾全局信息,最后结合强化学习迭代学习样本权重,从而提高不平衡故障分类效果。在此过程中,需要设计合理的奖励函数配合细粒度知识蒸馏分类器去优化样本权重。相比其他对比方法,本发明的方法有良好的效果和适用性。

著录项

  • 公开/公告号CN113222035A

    专利类型发明专利

  • 公开/公告日2021-08-06

    原文格式PDF

  • 申请/专利权人 浙江大学;

    申请/专利号CN202110549644.6

  • 发明设计人 张新民;范赛特;魏驰航;宋执环;

    申请日2021-05-20

  • 分类号G06K9/62(20060101);G06N3/04(20060101);G06N3/08(20060101);

  • 代理机构33200 杭州求是专利事务所有限公司;

  • 代理人贾玉霞

  • 地址 310058 浙江省杭州市西湖区余杭塘路866号

  • 入库时间 2023-06-19 12:07:15

说明书

技术领域

本发明属于工业过程监测领域,尤其涉及一种基于强化学习和知识蒸馏的多类别不平衡故障分类方法。

背景技术

在机器学习或深度学习分类中,类别样本数量不平衡是一个非常普遍的问题,广泛存在于各个领域,例如生物信息学,智能电网,医学成像,故障诊断。大多数现有的分类方法都基于以下假设:观测数据的基本分布是相对均衡的。但是,实际工业数据集通常会违反此假设,并呈现出偏斜的分布甚至是极度不平衡的类别样本数量分布。例如,数据驱动的故障分类是工业过程监测的重要组成部分,由于故障发生的频率不同,它们表现出不平衡的偏斜分布。在这种情况下,如果假定所有类别都具有同等的重要性,则分类器会倾向于分对频繁(多数)类别的样本而不是不频繁(少数)类别的样本。因此,迫切需要提出恰当的方法来消除不平衡的类别分布的负面影响,而又不过度牺牲任何多数类别或少数类别的准确性。

发明内容

本发明的目的在于提供一种基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其能对多数类不平衡的分类问题,获得较好的故障分类结果。具体技术方案如下:

一种基于强化学习和知识蒸馏的多类别不平衡故障分类方法,包括以下步骤:

S1:离线建模

S1.1:收集K个类别的历史离线工业过程数据样本,其中包含故障数据和正常数据;

S1.2:计算每个类别特征中心点

S1.3:通过基于Ward-Linkage的层次聚类,将同质类的类别特征中心分配在一个簇类中,最终将所有类别特征中心u

S1.4:使用高斯伯努利限制玻尔兹曼机,分别基于所有样本以及每个簇类中样本进行训练,其中,所有样本训练得到的高斯伯努利限制玻尔兹曼机参数为教师网络的预训练参数;基于每个簇类中样本训练得到的高斯伯努利限制玻尔兹曼机参数为对应的学生网络的预训练参数;

S1.5:基于所述的教师网络的预训练参数,采用所有样本,通过微调技术,训练多类别不平衡的教师网络,得到的logit作为所有学生网络的软目标;

S1.6:训练完教师网络之后,所有学生网络都通过综合交叉熵损失一起训练;根据包含所述软目标和硬目标的综合损失,采用每个簇类中样本,通过微调技术进行训练,将所有学生网络得到的logit拼接在一起,组成学生网络的综合logit;各个学生网络拼接的每个logit中值的位置对应于原先类别顺序;所述硬目标为样本的真实标签;

S1.7:使用强化学习结合知识蒸馏的输出来学习样本权重,并结合学习后的样本权重、教师网络和各个学生网络的输出构建损失函数;

S1.8:重复S1.5~S1.7,进行强化学习模型和知识蒸馏模型迭代训练,直到模型收敛;

S2:在线应用测试

S2.1:获取在线样本;

S2.2:将在线样本分类到S1.3层次聚类得到的C个簇类的其中一个簇类中;

S2.3:基于S1.8训练得到的知识蒸馏模型中的教师网络和各个学生网络,计算在线样本经过所在的簇类对应的学生网络得到的logit,和通过强化学习模型得到样本权重w

进一步地,所述S1.2中的特征中心点计算具体为:

其中,u

进一步地,所述S1.3具体为:

基于Ward-Linkage进行层次聚类,直到最后所有样本都聚成一个簇类。主要有以下步骤:

①在初始化过程中,将每个样本独立的归为一个簇类中;计算每两个簇类中心之间的相似度;

②找到两个最近的簇类,并将它们归为一个簇类,因此簇类总数减少1个;

③重新计算新生成簇类的中心与每个旧簇类中心之间的相似度;所述簇类的中心为一个簇类的所有样本的平均值;

④重复②和③,直到所有样本归为一个簇类,聚类算法结束;

⑤选择所需的最终聚类后的簇类数,作为最终的簇类数,即C的值。

进一步地,所述S1.4中的高斯伯努利限制玻尔兹曼机具有两层全连接的结构,分为可见单元

其中v

所述高斯伯努利限制玻尔兹曼机的目标函数为:

其中,x

通过随机梯度上升方法最大化以找到最佳θ,完成对所述高斯伯努利限制玻尔兹曼机的训练:

其中,θ中的w和b用作知识蒸馏神经网络第一层的初始参数。

进一步地,所述S1.5通过梯度下降法训练教师网络,其中,教师网络的交叉熵损失函数如下:

其中

进一步地,所述S1.6通过梯度下降法训练学生网络,其中,学生网络的交叉熵损失函数如下:

其中

进一步地,所述S1.7具体为:

设定π

(1)初始化样本权重(动作):

a

其中w

(2)计算老师-学生网络的加权交叉熵损失:

其中

(3)计算奖励r

r

其中F

(4)获取状态s

(5)更新策略π

进一步地,将在线样本分到对应簇类中,所述S2.2具体为:

在线样本分类到对应的簇类中,其公式如下:

其中c为在线样本的簇类类,

进一步地,所述S2.3具体为:

用强化学习学习到的π

w

在线样本经过学生网络得到的输出为:

logit=w

其中,f

本发明的有益效果如下:

本发明对于多类别的不平衡故障分类问题具有独特的效果,由于同质类别样本之间存在相似性、异质类样本之间存在较大差异的特点,使得本发明在通过聚类方法得到的簇类的基础上,更加细粒度的通过多个学生网络来解决不平衡的故障分类问题。同时通过教师网络的引导使得各个学生网络不仅能学习到簇类中同质类别的决策边界,也能学习到总体的数据分布信息。不仅如此,进一步结合强化学习,不断结合识蒸馏网络进行迭代,结合样本类别数量与样本在分布中的作用,获取样本权重,增加少数类样本的权重,减少多数类样本的权重,使得故障分类效果更好,准确率更高。

附图说明

图1为本发明方法采用的基础方法的结构图;

图2为本发明方法的结构图;

图3为使用的数据集生成的工艺流程图;

图4为使用的数据集样本数量分布示意图;

图5为本发明方法训练奖励和测试G-mean的曲线图;

图6为通过层次聚类得到的树状图;

图7为所有对比方法10次运行后绘制的箱线图;

图8为分类最后一层隐层的数据通过t-SNE降维后的2D映射图。(a)为MLP最后一层隐层输出的2D映射图;(b)为SMOTE-MLP最后一层隐层输出的2D映射图;(c)为Cosen-MLP最后一层隐层输出的2D映射图;(d)为CSDBN-DE最后一层隐层输出的2D映射图;(e)为TU-MLP最后一层隐层输出的2D映射图;(f)为KD最后一层隐层输出的2D映射图;(g)为本发明最后一层隐层输出的2D映射图。

具体实施方式

下面根据附图和优选实施例详细描述本发明,本发明的目的和效果将变得更加明白,应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。

针对多类别的不平衡分布问题,本发明提出了一种新的基于强化学习和知识蒸馏的多类别不平衡故障分类方法。

本发明针对多类别的不平衡分布下的故障分类问题,划定离线和在线数据集,首先使用知识蒸馏方法进行分类或识别故障的类别。再针对不平衡问题中同质类别样本之间存在相似性、异质类样本之间存在较大差异的特点,采用层次聚类方法,根据类别中心点的聚类结果,将所有类别样本进行聚类,从而获得细粒度簇类。最后针对每个簇类进行细粒度故障分类。因此,对于某个簇类中,都将建立一个学生模型,最后进行拼接,进行多学生模型一起优化。在教师模型的全局信息的指导下,并结合多学生模型细粒度的进行故障分类。不仅如此,进一步结合强化学习,不断结合识蒸馏网络进行迭代,结合样本类别数量与样本在分布中的作用,获取样本权重,增加少数类样本的权重,减少多数类样本的权重,使得故障分类效果更好,相比其他现存方法,本发明的方法有良好的效果和适用性。

如图1和2所示,本发明的基于强化学习和知识蒸馏的多类别不平衡故障分类方法,包括以下步骤:

S1:离线建模

S1.1:收集K个类别的历史离线工业过程数据样本,其中包含故障数据样本和正常数据样本;

S1.2:计算每个类别特征中心点

其中,g

S1.3:通过基于Ward-Linkage的层次聚类,将所有类别特征中心u

①在初始化过程中,将每个样本独立的归为一个簇类中。计算每两个簇类中心之间的距离(也称为相似度);

②找到两个最近的簇类,并将它们归为一个簇类,因此簇类总数减少1个;

③重新计算新生成簇类的中心与每个旧簇类中心之间的相似度(一个簇类的所有样本的平均值代表该簇类的中心);

④重复②和③,直到所有样本归为一个簇类,聚类算法结束;

⑤选择所需的最终聚类后的簇类数,作为最终簇类数,即C的值。

整个聚类过程实际上是在构建一棵树。在构建过程中,第②步将设置一个阈值。当两个最近的簇类中心之间的距离大于此阈值时,则认为迭代已终止。另一个关键步骤是第三步,有很多方法可以确定两个聚类之间的相似性。常用的相似性度量包括Ward Linkage,Single Linkage、Complete Linkage和Average Linkage策略。在发明中,由于WardLinkage策略通常提供较高的聚类性能,因此采用Ward Linkage策略。Ward Linkage由两个聚类之间的平方误差和ESS计算得出,其目标函数是每次合并后ESS的最小增量,ESS定义如下:

S1.4:使用高斯伯努利限制玻尔兹曼机,分别基于所有样本以及每个聚类中样本进行训练。其中,所有样本训练得到的高斯伯努利限制玻尔兹曼机参数为教师网络的预训练参数;基于每个簇类中样本训练得到的高斯伯努利限制玻尔兹曼机参数为对应的学生网络的预训练参数。

高斯伯努利限制玻尔兹曼机具有两层全连接的结构,分为可见单元(或数据变量)

其中v

通常,将导致高(低)能量的配置(v,h)分别设置为低(高)概率计算的一部分。所有可见单元或隐藏单元都是有条件的独立单元。因此,高斯伯努利限制玻尔兹曼机的可见节点和隐藏节点的概率分布可以由下式给出:

其中σ(x)是逻辑斯蒂sigmoid函数

现有的大多数基于高斯伯努利限制玻尔兹曼机的模型都是通过对比差异(CD)学习策略来处理数据非线性的,该策略将实值数据映射到隐特征空间。对数似然估计值可通过随机梯度上升方法最大化以找到最佳θ:

通过迭代获得高斯伯努利限制玻尔兹曼机的最优参数θ。θ中的w和b用作知识蒸馏神经网络第一层的初始参数。

S1.5:基于所述的教师网络的预训练参数,采用所有样本,通过微调技术,通过梯度下降法训练多类别不平衡的教师网络,得到的logit作为所有学生网络的软目标。计算教师网络ft的交叉熵损失函数如下:

其中

S1.6:训练完教师网络之后,所有的学生网络都通过综合交叉熵损失一起训练。根据包含了软目标(教师网络的logit)和硬目标(真实标签)的综合损失,采用每个簇类中样本,通过微调技术,通过梯度下降法进行训练所有学生网络。学生网络的综合logit由所有学生网络的logit拼接在一起。各个学生网络拼接的每个logit中值的位置对应于原先类别顺序。学生网络f

其中

S1.7:使用强化学习结合知识蒸馏的输出来学习样本权重,并结合学习后的样本权重、教师网络和各个学生网络的输出构建损失函数;

S1.8:重复S1.5~S1.7,进行强化学习模型和知识蒸馏模型迭代训练,直到模型收敛;

设定π

(1)初始化样本权重(动作):

a

其中w

(2)计算老师-学生网络的加权交叉熵损失:

其中

(3)计算奖励r

r

其中F

(4)获取状态s

(5)更新策略π

S2:在线应用测试

S2.1:获取在线样本;

S2.2:基于S1.3层次聚类得到的簇类信息,将在线样本分类到对应的簇类中。在线样本分类到对应的簇类中,其公式如下:

其中c为在线样本的簇类类,

S2.3:基于S1.8训练得到的知识蒸馏模型中的教师网络和各个学生网络,计算在线样本经过所在的簇类对应的学生网络得到的logit,和通过强化学习模型得到样本权重w

用强化学习学习到的π

w

在线样本经过学生网络得到的输出为:

logit=W

其中,f

以下结合一个具体的工业例子来说明本发明的有效性。使用田纳西州伊士曼(TE)工业基准来评估所提出的方法。TE过程是由伊士曼化学公司根据实际化学过程开发的工业仿真平台,已广泛用于测试过程监控和故障诊断方法的有效性。TE过程的流程如图3所示。

表1:每个故障类别TE过程训练样本数量设定

TE数据中正常样本数量为8000。表1为每个故障类别TE过程训练样本数量设定,测试样本数量设定为2000。TE数据的过程变量由34维,故障类别有28个,如图4所示。选取对比方法有MLP(多层感知机)、SMOTE-MLP(合成少数类过采样技术的MLP)、CoSen-MLP(代价敏感MLP)、CSDBN-DE(差分演化的代价敏感深度信念网络)、TU-MLP(可训练的降采样器结合MLP)、KD(知识蒸馏)和本发明(基于强化学习和知识蒸馏的多类别不平衡故障分类方法)。

通过基于强化学习和知识蒸馏的多类别不平衡故障分类方法在TE过程训练样本上训练得到各个学生模型。通过离线训练得到的学生模型对在线样本(测试集)进行预测,得到的结果如表2所示:

表2:在TE过程数据上各个对比方法的分类性能

从表2中可以看出,所提出的基于强化学习和知识蒸馏的多类别不平衡故障分类方法的F1随着不平衡率的上升在更多的类别上优于对比方法,且随着不平衡程度的提高,本发明相比其他对比方法的优势越明显。综合所有对比方法在所有类别上的结果,本发明提出的方法可以在最终的Macro-F1和Gmean指标上明显优于其他方法。

图5为本发明的训练奖励和测试G-mean的曲线图,可以看出算法收敛较为稳定,并能够根据设定达到较优的性能。图6为用层次聚类方法得到的树状图,虚线为决策线,总共分为3个簇类。图7为所有对比方法10次运行后绘制的箱线图,本发明相对其他对比方法性能更好,更稳定。

为了方法优越性更加直观和明显,绘制了各个分类模型最后一层隐藏的输出经过t-SNE后得到的2D图,如图8所示。图8(g)为本发明的2D映射图,能够从图中看出,经过基于强化学习和知识蒸馏的多类别不平衡故障分类方法,获得降维2D图中的各个类别的边界更加明显,这充分体现了算法的分类性能得到了提高。

如上所述,本发明中所提的基于强化学习和知识蒸馏的多类别不平衡故障分类方法,具有令人满意的分类效果。

去获取专利,查看全文>

相似文献

  • 专利
  • 中文文献
  • 外文文献
获取专利

客服邮箱:kefu@zhangqiaokeyan.com

京公网安备:11010802029741号 ICP备案号:京ICP备15016152号-6 六维联合信息科技 (北京) 有限公司©版权所有
  • 客服微信

  • 服务号