技术领域
本发明涉及图像处理技术领域,具体为一种基于注意力机制的脱机英文手写识别方法。
背景技术
为将书写于纸上的信息数字化以便于后期的查询检索,最简单的方法是采取专人录入,但这种方式极大浪费人力物力以及时间。光学文字识别(Optical CharacterRecognition,OCR)实现了机器“读懂”人类手写文字,但由于脱机手写文字的风格迥异,一般的卷积神经网络模型提取出的图像特征表示力不强,如2019年Carbonell等人提出的针对全文本的检测识别方法,其对手写字符串的识别错误率非常高,需要很多后续的处理工作,无法满足实际应用的需求。
发明内容
为了解决现有的手写识别方法错误率较高的技术问题,本发明提供一种基于注意力机制的脱机英文手写识别方法,其可以提高对手写字符串的识别率,满足实际应用的需求。
本发明的技术方案是这样的:一种基于注意力机制的脱机英文手写识别方法,其包括以下步骤:
S1:选取书写来源不同的脱机手写单词图像,并进行预处理,添加标签,获得训练数据集;
S2:构建脱机英文手写识别模型;
S3:将已知标签的所述训练数据集输入到所述脱机英文手写识别模型中进行训练,获得训练好的所述脱机英文手写识别模型;
S4:将待识别手写单词的图像输入到训练好的所述脱机英文手写识别模型中,获得识别结果;
其特征在于:
所述脱机英文手写识别模型包括:依次连接的基于卷积神经网络构建的图像特征提取模块、基于双向长短期记忆网络模型构建的序列特征提取模块、生成模块;
所述图像特征提取模块直接从输入的图像中提取特征,生成特征序列,其包括:9个卷积层,4个池化层,3个注意力模块层以及2个批量标准化层;其中,将9个卷积层分成5个卷积块,前四个所述卷积块包括连续的两个卷积层,最后一个所述卷积块包括一个卷积层,每个卷积层后面都设置一个修正线性单元激活函数;
前四个所述卷积块后面分别跟着一个所述池化层,前三个所述卷积块的所述池化层前面分别设置一个注意力模块层;
第四个所述卷积块中,每个卷积层和修正线性单元激活函数之间设置一个批量标准化层;
所述序列特征提取模块从图像特征中学习序列联系,其包括依次连接的双向长短期记忆网络模型、全连接层;
所述生成模块把特征分布转化成标签序列,其包括:损失函数层。
其进一步特征在于:
所述图像特征提取模块中,卷积层全部采用3×3大小的卷积核;池化层采用最大池化或2×1尺寸;
所述双向长短期记忆网络模型中,隐藏层单元设为256;
所述序列特征提取模块中,损失函数层采用CTC损失函数;
所述注意力模块层中的计算过程,详细如下所示:
a1:将输入特征图F分别输入到通道注意力模块Ms和空间注意模块Mc中,并联的获取到特征图的通道注意力映射M
其中,
MLP为由多层感知机组成的共享网络对这两个不同的空间背景进行计算;在多层感知机中,隐层神经元个数为C/8,输出层单元个数为C;
W
其中,
σ表示Sigmoid激活函数;f
a2:将空间注意模块输出的特征图F
a3:将所述3D的注意力特征图F′与所述注意力模块层的所述输入特征图F的卷积结果相加得到最终的细化特征图F"';其计算过程如下所示:
F"=Relu(f
F″′=F′+F″;
步骤S3中,对所述脱机英文手写识别模型中进行训练时,学习率更新策略采用如下方法:
lr=base_lr*decay_rate
其中:
base_lr为初始化学习率,global_step为当前迭代次数,decay_rate为学习率衰减系数,decay_steps为常数。
本发明提供的一种基于注意力机制的脱机英文手写识别方法,构建的脱机英文手写识别模型包括基于注意力机制的卷积神经网络和双向长短期记忆网络模型,基于注意力机制构建的卷积神经网络提取的图像特征,特征图经注意力模块后更能聚焦有用特征而非无用的手写拖拽特征,使得提取的图像特征更加关注有用信息,忽略无用信息,进而提高了图像识别的准确率;本专利技术方案中的注意力模块层,没有采用原始注意力模块CBAM中先将输入特征图送入通道注意力模块,再将输出结果送入空间注意力模块的串联方式,而是需要输入的特征图同时经过通道注意力和空间注意力模块,并联地获取到各自的注意力映射,之后分别与输入特征图做点乘得到通道注意力特征图和空间注意力特征图,在增强有用的特征表示的同时抑制无用特征的干扰;比起传统的CBAM,本专利技术方案中的注意力模块层避免了先经过通道注意力模块再经过空间注意力模块后,注意力映射M
附图说明
图1为本专利的脱机英文手写识别模型的网络结构示意图;
图2为本专利的注意力模块层的结构示意图。
具体实施方式
本发明一种基于注意力机制的脱机英文手写识别方法,其包括以下步骤。
S1:选取书写来源不同的脱机手写单词图像,并进行预处理,添加标签,获得训练数据集;
本实施例中,预处理操作包括:将脱机手写单词图像的高度规范为32像素,宽度也等比例缩放。
S2:构建脱机英文手写识别模型;
如图1所示,脱机英文手写识别模型包括:依次连接的基于卷积神经网络(CNN)构建的图像特征提取模块(Convolutional layers)、基于双向长短期记忆网络模型构建的序列特征提取模块(RecurrentLayers)、生成模块(Transcription layers)。
图像特征提取模块(Convolutional layers)直接从输入的图像中提取特征,生成特征序列,其包括:9个卷积层,4个池化层,3个注意力模块层以及2个批量标准化层;其中,将9个卷积层分成5个卷积块,前四个卷积块包括连续的两个卷积层,最后一个卷积块包括一个卷积层,每个卷积层后面都设置一个修正线性单元(ReLU)激活函数;通过修正线性单元(Rectified Linear Units,ReLU)激活函数避免梯度消失问题,进而提高识别模型对手写字符串的分类识别率;
前四个卷积块后面分别跟着一个池化层,前三个卷积块的池化层前面分别设置一个注意力模块层;
第四个卷积块中,每个卷积层和修正线性单元激活函数之间设置一个批量标准化层(BN层),批量标准化层将特征图的数据分布重新规范,使非线性函数的输入值远离梯度饱和区,加快了网络的训练速度;
序列特征提取模块(RecurrentLayers)从图像特征中学习序列联系,其包括依次连接的双向长短期记忆网络模型、全连接层;双向长短期记忆网络模型中,隐藏层单元设为256,后连接全连接层;
生成模块(Transcription layers)把特征分布转化成标签序列,其包括:损失函数层,本专利中损失函数采用CTC损失函数;生成模块(Transcription layers)将序列特征提取模块(RecurrentLayers)输出的特征对应到相应的标签,获得输出的字符(OutputSequence),进行后续处理,获得预测序列;
如图1所示,Input image输入到图像特征提取模块(Convolutional layers)中,提取的特征以为Feature Map形式输出,Feature Map输入到序列特征提取模块(RecurrentLayers)从图像特征中学习序列联系,经双向长短期记忆网络模型帮助对传入的图像特征进行预测,提高预测模型对字符预测的速度;最后基于生成模块(Transcription layers)中的CTC损失函数定位到字符对应的标签,进行后续处理后输出预测序列;其中,双向长短期记忆网络模型中,隐藏层单元设为256,后连接全连接层。由于我们的实验是针对脱机手写单词的识别,字符字典中除了大小写的52个英文字母和4个常见符号,CTC损失函数中还要求引入“blank”标签,所以在双向长短期记忆网络的输出层设置57个输出。
在图像特征提取模块中,卷积层全部采用3×3大小的卷积核;池化层采用最大池化或2×1尺寸,最大池化能够将高层的稀疏特征继续传递下去,更多地保留图像的有用信息,提高后续预测的准确率。
具体网络结构以及详细参数见表1;其中,k代表卷积核大小/局部窗口大小,n为卷积核个数(通道数),s表示步长,p表示padding方式,hidden_units表示双向长短期记忆网络的隐藏单元个数,w表示权重矩阵大小。
表1:网络结构详细参数
如图2所示,注意力模块层中的计算过程,详细如下所示:
a1:将输入特征图F分别输入到通道注意力模块Ms和空间注意模块Mc中,并联的获取到特征图的通道注意力映射M
其中,
MLP为由多层感知机组成的共享网络对这两个不同的空间背景进行计算;在多层感知机(MLP)中,隐层神经元个数为C/8,输出层单元个数为C;
W
其中,
a2:将空间注意模块输出的特征图F
a3:将3D的注意力特征图F′与注意力模块层的输入特征图F的卷积结果相加得到最终的细化特征图(Refined Feature)F"';其计算过程如下所示:
F"=Relu(f
F″′=F′+F″。
综上,通道注意力映射M
S3:将已知标签的训练数据集输入到脱机英文手写识别模型中进行训练,获得训练好的脱机英文手写识别模型;
对脱机英文手写识别模型中进行训练时,模型参数初始化设置如下:基础学习率(base_lr)设为0.1,训练轮次设为30000,单批次大小(batch_size)为16,衰减系数(decay_rate)设置为0.8;
学习率更新策略采用如下方法:
lr=base_lr*decay_rate
其中:
base_lr为初始化学习率,global_step为当前迭代次数,decay_rate为学习率衰减系数,本实施例中设置为0.8,decay_steps为常数,本实施例中设置为2000。
S4:将待识别手写单词的图像输入到训练好的脱机英文手写识别模型中,获得识别结果;本发明不需要手工提取特征,可以端到端地对脱机手写单词图像进行识别,方法简单且识别的字符错误率(character error rate,CER)低。
为了验证本发明方法的优越性,在IAM脱机英文手写数据集上进行测试;IAM脱机英文手写数据集由657个不同作者手写的1539个扫描文本页面组成,对应于从LOB语料库中提取的英语文本。每张文本页又按文本行和单词切分。取数据库中1904张脱机手写单词图像进行实验,为提高本方法的可信度,取两个现有识别方法进行对比实验,其中方法一是2018年Sueiras等人提出的基于sequence to sequence框架的识别方法,方法二是2019年Carbonell等人提出的针对全文本的检测识别方法,最终的对比测试结果显示在表2。
表2:在IAM数据库上的字符错误率
从表2中可以看出,本发明提出的方法在IAM数据库上字符错误率(charactererror rate,CER)更低。
机译: 手写英文字符和数字字母识别方法
机译: 一种基于笔画识别的在线手写韩文识别方法
机译: 一种基于模糊推理的手写字符识别方法