技术领域
本发明属于信息技术领域,涉及自然语言处理、文本分类及多标签分类,具体涉及一种基于改进GraphRNN的多标签文本分类模型及分类方法。
背景技术
多标签分类是机器学习领域中一个重要的学习任务,被广泛应用于文本分类、图 像标注、推荐系统等实际场景。在多标签分类问题中,样本可以被分配到多个标签上。假设 样本空间
从利用标签关联的角度出发,目前多标签分类方法主要可分为三种:1、一阶方法,典型算法如Binary Relevance、ML-KNN,这类方法简单有效,但忽略了标签关联;2、二阶方法,典型算法如Rank-SVM、Calibrated Label Ranking,这类方法只考虑标签对两两之间的联系,如排序问题中相关标签和不相关标签的排列关系;3、高阶方法,这类方法能建模多个标签之间的关联,但需要预定义顺序,存在累计误差,典型的传统算法如ClassifierChains、Ensemble Classifier Chains。
在以上三类方法中,目前围绕高阶方法展开的研究居多。由于循环神经网络在处理序列数据上的表现优异,近年来不少研究利用循环神经网络建模标签高阶关联,在序列到序列(Sequence to Sequence,Seq2Seq)模型下将多标签分类转换为序列生成问题,但模型易受标签顺序影响。为缓解标签顺序带来的影响,有研究将多标签分类视为标签集合预测,但无法区分标签关联程度,而标签间关联应有不同程度之分。
总的来说,现有的多标签分类方法,在建模标签高阶关联时,要么受标签顺序限制,要么没有考虑标签关联的具体程度。
发明内容
本发明针对多标签分类中如何利用标签关联的问题,提出了一种基于改进GraphRNN的多标签文本分类模型及分类方法,根据标签共现关系构建标签图数据,将多标签分类转换为标签图生成问题,可避免预定义标签顺序;利用标签共现信息,建模标签关联程度,可以更细致地建模标签关联。
为解决上述技术问题,本发明采用的技术方案如下:
一种基于改进GraphRNN的多标签文本分类模型,所述分类模型包括编码器、解码器和Graph2Seq模块,输入信息由编码器进行编码,送入解码器生成标签图,再由Graph2Seq模块将生成的标签图转换为标签集;
所述解码器由改进GraphRNN构成,具体是:基于GraphRNN图生成模型并对其进行改进,改进GraphRNN由节点生成和边生成组成,其中,所述节点生成添加softmax模块后生成标签节点,建模标签关联,所述边生成由二分类改为多分类,建模标签不同的关联程度。
进一步地,所述节点生成为:
节点生成以“
GRU为门控循环单元,
进一步地,所述边生成为:
在预测出节点
经过softmax模块得到第
当预测节点标签为“
进一步地,所述编码器具体为:
单个样本的文本信息
在0时刻,隐状态
进一步地,所述Graph2Seq模块具体为:根据改进GraphRNN得到的标签图,Graph2Seq模块在标签图上进行广度优先搜索得到最终的标签集,作为多标签分类结果。
一种基于改进GraphRNN的多标签文本分类模型的分类方法,包括以下步骤:
步骤1:将原始样本的标签集转换为标签图;
步骤2:文本预处理,包括分词、词语向量化、划分数据集;
步骤3:划分数据集,分为训练集、验证集、测试集;
步骤4:训练模型,在验证集上调整超参数前驱节点个数
步骤5:将新样本送入训练好的模型,预测对应的标签结果。
进一步地,所述步骤1具体为:
将单个样本的所有标签对视为共现标签对,对于
根据
根据
标签连边确定之后构建标签图。
进一步地,在步骤3中,将数据集进行随机划分,数据集中80%数量的数据作为训练集,而验证集和测试集各为10%数量的数据。
与现有技术相比,本发明的有益效果是:将多标签分类转换为图生成问题,可缓解标签顺序给模型带来的影响,提升了预测结果在instance-F1和label-F1指标的表现效果,不仅可以建模标签关联,还可以建模标签关联程度,从而更细致地建模标签关联。
附图说明
图1是本发明多标签文本分类模型示意图。
图2是基于本发明分类模型的分类方法流程图。
图3是将单个文本的所有标签对视为共现标签对示意图。
图4是本发明中展示的标签集转换为标签图的过程示意图。
具体实施方式
下面结合附图和具体实施方式对本发明作进一步详细的说明。
一、多标签文本分类模型组成
本发明分类模型由编码器(Encoder)、解码器(改进GraphRNN)和Graph2Seq构成。本发明模型框架如图1所示,输入信息由Encoder进行编码,送入改进GraphRNN进行标签图生成,最后Graph2Seq将生成的标签图转换为标签集,作为分类结果。
1、编码器(Encoder)
负责对输入文本信息进行编码。单个样本的文本信息。单个样本的文本信息
单个样本的文本信息经过GRU进行编码,第
2、解码器(改进GraphRNN)
GraphRNN是You等提出的图生成模型,包括节点生成(Node-level RNN)和边生成 (Edge-level RNN)两部分,但缺少节点输出模块,并且GraphRNN中的边生成为二分类问题, 无法区分多种标签关联程度,因此本发明通过改进GraphRNN,使其既能输出节点也能输出 邻接向量。改进GraphRNN同样由节点生成(Node-level RNN)和边生成(Edge-level RNN)组 成,改进点包括两个方面:
(1)Node-level RNN:节点生成以“
(2)Edge-level RNN:在预测出节点
经过softmax得到第
在图1中,当预测节点标签为“
3、Graph2Seq模块
Graph2Seq模块将生成的标签图转换为标签集,作为分类结果。具体地,由改进 GraphRNN生成的节点集合
二、原理说明
为了将seq2seq用于多标签分类时减少标签顺序带来的影响,本发明将多标签视 为集合,为了描述这种集合,本发明使用标签图来表示标签集。将节点集合表示为
在式(10)中,
三、基于改进GraphRNN的多标签分类流程
如图2所示,基于改进GraphRNN的多标签分类流程包括如下几个步骤:
1、数据转换:将原始样本的标签集转换为标签图;
在建立模型之前,需要将文本对应的原始标签集转换为标签图
根据
标签连边确定之后即可构建标签图,图4展示了标签集{A,B,C}转换为标签图的过程。
2、文本预处理:分词,词语向量化,划分数据集;
3、划分数据集:将数据集进行随机划分,数据集中80%数量的数据作为训练集,而验证集和测试集各为10%数量的数据。
4、训练、测试模型;
训练模型,在验证集上调整超参数前驱节点个数
对比方法说明如下:
(1) Binary Relevance(BR):将多标签分类问题转换为二分类问题,没有利用标签之间的相关性。
(2) Classifier Chains(CC):将多个BR级联起来,前一分类器输出作为后一分类器输入,该方法能考虑到标签之间的高阶关联。
(3) Ensemble Classifier Chains(ECC):在Classifier Chains的基础上,选择不同的标签顺序,结合集成学习训练模型。
(4) seq2seq-GRU:在seq2seq模型下基于GRU,按标签频次降序训练模型,生成多标签序列。
(5) set-RNN:将多标签视为标签集合,直接生成多标签集合。
(6)改进GraphRNN:本发明方法,将原始问题转换为标签图生成问题,在图上进行BFS得到分类结果。
5、将新样本送入训练好的模型,预测对应的标签结果。
本发明将多标签分类问题转换为图生成问题,因此在训练模型之前,需要将原始样本标签集转换为与之对应的标签图,进一步训练模型。模型训练完成后,即可用于预测新样本的标签。
机译: 基于语义表示模型的文本分类方法和装置,以及计算机设备
机译: 文本分类模型训练方法,文本分类方法和装置,以及电子设备
机译: 建立文本分类模型的方法和装置,以及文本分类方法和装置