目前在训练这个模型,所以重新回顾一下这篇论文,并且结合GitHub上的高分实现代码https://github.com/backtime92/CRAFT-Reimplementation,一起分析这篇论文的精彩部分,也就是采用半监督学习处理样本。因为craft需要针对字符级别的标注进行训练,但是一般的数据集比如icdar15等都是基于单词级别打标的,所以这种思想可以引导我们思考如何利用现有的粗粒度的样本来实现细粒度的GT要求。
一、相关工作
文字检测模型可以分为
1. 基于回归的文字检测,包括textboxes系列方法。主要使用卷积和anchor box来确定文字框的位置
2. 基于分割的文字检测 pixellink、Multi-scale FCN、SSTD、TextSnake
3. 端到端的文字检测 同时文字检测和识别,FOTS,EAA,MaskTextSpotter
4. 字符级别的文字检测 seglink首先检测字符,然后使用link prediction得到字符之间的连接关系。
二、CRAFT模型
本文收到WordSup启发,但是克服了其易受文字形状(矩形限制)的弱点。
首先介绍一下如何使用高斯热力图表示每个文字的位置信息和文字之间的连接关系。
每个字符的区域是一个四边形,正常的二维高斯图是中间值最大,往外值逐渐减小的二维矩阵,矩阵中元素的值在[0,1]之间;那么对于一个四边形来说,只要知道了四个角点的坐标就可以通过透视变换得到标准高斯图变换到该四边形的变换矩阵,因此由标准高斯图和变换矩阵就可以得到变换后的高斯图,这个高斯图就可以作为这个字符的region score map,对一张图上的所有文字都做这样的变换,就可以得到全图所有位置的region score,这个就可以作为样本的标签
那么字符之间的连接关系怎么得到?对于一个Word单词来说,他由很多字符组成,既然我们已经知道了字符的位置,那么很容易得到字符的几何中心点,几何中心点和上边的两个坐标连接,得到一个上三角形,和下边的两个点相连得到下三角形,这两个三角形的中心就是连接图的一侧两个坐标,由此可以得到两个字符之间的连接矩形坐标。按照上面的方式,同理可以得到整张图的affinity score map。
通过上述的操作,我们知道了怎么用高斯图得到region score和affinity score,之后模型需要学习的就是这两张热力图。
region score ground truth
image.jpeg
affinity score GT
image.jpeg
本文预测每个字符的位置,并且预测字符和字符之间的关系;模型的输出是两张feature map,其中一个用于给出每个位置char score的热力图,另一个给出affinity score的热力图。在后处理过程中,根据这两张图,采用连通域分析等传统图像处理技术,最终得到文本行的位置信息。
因为字符级别的数据比较少,所以模型训练使用弱监督方式进行。
三、模型结构
模型结构比较简单,采用一个VGG_16作为基础网络,然后按照Unet的思想上采样+融合这样,最终得到一张2个channel的输出结果。结合代码可以更加清楚的看到网络结构。
image.jpeg
损失函数
输出的是两张热力图,ground_truth也是热力图,热力图上的每个元素都是[0,1]之间的分数,因此采用全图的均方误差表示损失函数:
image.jpeg
其中Sc表示置信度,因为对于生成的样本来说,我们知道每个字符的位置,所以置信度始终为1,但是对于类似ICDAR15等数据集,他的打标方式是只对文本框打标,不对单字打标,所以需要考虑置信度,这个暂时不提。
Sr表示region score map的元素值,Sa表示affinity score map的值。
四、 训练策略
训练策略:
- 使用SynthText数据训练5万轮,
- 使用公开训练数据集fine-tune
- 使用ADAM优化器
- 在fine-tune的时候也会使用SynthText的数据保证字符区域的分隔,比率为1:5
- 采用在线难例挖掘,比率为1:3
- 采用了常见的数据扩充技术,比如建材,旋转或者颜色变化
- 对于半监督学习,生成假的GT的步骤放在额外的GPU上进行,生成之后会保存在内存中?
对于弱监督学习的数据集,需要满足下面的条件,首先需要有标注的四边形,其次需要有具体的文字信息。
- 实际满足要求的训练数据只有IC13, IC15, and IC17。因此实际训练的时候,作者采用IC15得到一个模型,使用IC13和IC17得到另一个模型。每个模型都fine-tune了25000iters
图片的长边都被resize到960,2240,2560,1600
五、样本是怎么生成的
这是这篇论文的精华部分吧,但是具体需要结合实际的代码来看。目前GitHub上有一份对于CTAFT复现+训练的高分项目,项目地址如下:https://gitee.com/shuanghaochen/CRAFT-Reimplementation(GitHub上速度比较慢,这个下载快一点)
可以结合代码看这部分。
对于样本可以分为两类,一类是SynthText这类电脑生成的样本,这些样本可以直接得到每个字符和文字的位置信息,可以直接采用来生成热力图。另一类是icdar15这类只进行了单词级别的打标的数据集。
5.1 SynthText GT生成步骤
- 读取字符框位置、读取单词、获得图片
- 随机缩放图片,对应的字符框位置也改变
- 根据图片大小和字符框位置生成region-score高斯热力图
- 根据region-score热力图,图片大小、单词、字符框生成affinity-score map
- 对图片以及热力图进行一些随机化操作,比如随机crop,flip,旋转等操作。这里需要注意,如果自己生成的文本是比较密集的文本,那么随机crop的代码需要做改动,不然
- 将gt数据缩小一倍
5.1.1第3步具体是怎么生成热力图的?图片大小和字符框chabox
- 设置一个空白的和图片大小相同的目标图
height, width = image_size[0], image_size[1]
target = np.zeros([height, width], dtype=np.uint8)
- 对于所有字符框:
2.1. 已经预先定义了一个标准大小为512512的标准热力图origin_box,对这个标准热力图过滤掉低热力区域(比如周边的阈值小于0.4255的区域),因为我们在得到模型结果的时候,只会关注score高于特定阈值的区域,得到一个regionbox
2.2. 通过charbox字符框和regionbox得到一个从标准regionbox透视变换到charbox的变换矩阵M
M = cv2.getPerspectiveTransform(np.float32(regionbox), np.float32(target_bbox))
2.3. 通过变换矩阵M,得到标准热力图origin_box变换到charbox坐标系的坐标real_target_box
real_target_box = cv2.perspectiveTransform(oribox, M)[0]
real_target_box = np.int32(real_target_box)
2.4. 得到real_target_box的xmin,ymin,xmax,ymax
xmin = real_target_box[:, 0].min()
xmax = real_target_box[:, 0].max()
ymin = real_target_box[:, 1].min()
ymax = real_target_box[:, 1].max()
2.5. 对charbox的坐标系移动到边缘之后
_target_box[:, 0] -= xmin
_target_box[:, 1] -= ymin
2.6. 重新生成透视变换矩阵,然后将标准热力图变换到real_target_box的尺寸上
_M = cv2.getPerspectiveTransform(np.float32(regionbox), np.float32(_target_box))
warped = cv2.warpPerspective(self.standardGaussianHeat.copy(), _M, (width, height))
2.7. 将target图上对应的位置设为热力图的值
5.1.2 第4步:根据region-score热力图,图片大小、单词、字符框生成affinity-score map
- 生成一张全为0的热力图
- 对每个字符以及他右侧相邻的字符执行下面的操作
2.1 得到两个box的中心坐标
center_1, center_2 = np.mean(bbox_1, axis=0), np.mean(bbox_2, axis=0)
2.2 分别得到中心坐标和两个box的上下侧坐标点构成的三角形的中心;将得到的四个点作为affinity box的四点坐标
tl = np.mean([bbox_1[0], bbox_1[1], center_1], axis=0)
bl = np.mean([bbox_1[2], bbox_1[3], center_1], axis=0)
tr = np.mean([bbox_2[0], bbox_2[1], center_2], axis=0)
br = np.mean([bbox_2[2], bbox_2[3], center_2], axis=0)
- 执行和第三步一样的操作得到即可。
5.2 对于ICDAR15等针对单词打标而非字符打标的训练样本,如何生成GT?
在论文中,给出了生成GT的方式,这是一种半监督的生成方式,首先来看icdar15的标注信息:
845,196,1271,78,1278,130,851,247,MOLESKINE
1035,280,1162,257,1169,278,1041,301,MOLESKINE
31,365,107,375,106,405,30,395,###
843,297,908,286,912,298,848,310,###
前八位是顺时针的单词坐标,最后一位是文本信息;对于不可以辨认的文字,使用###表示。
在处理这样的打标数据时,首先获得文本行对应的局部图片,然后通过模型预测得到region score;根据region score使用分水岭算法得到字符的位置信息;然后将这个坐标信息恢复到原图的坐标轴上,作为GT;这种半监督的方式体现在,根据每个单词得到的字符box都是带有置信度的;置信度的计算方式为
image.jpeg
在训练的初期,置信度会比较低,对置信度低于0.5的样本,会根据单词包含的字符数量平均划分区域,进行标注。同时设置标签置信度为0.5。
整体的过程可以用下面图的上面部分来展示:
image.jpeg
实际在训练的时候,定义Dataset的时候会传入网络Net;在数据处理的时候,已经有预测网络Net,原图image,单词Word以及单词狂word_box;
1. 获得单词局部图片
2. 将单词图片透视变换到矩形图片,归一化高度为64
3. 通过Net网络预测得到region_score的结果
img_torch = torch.from_numpy(imgproc.normalizeMeanVariance(input, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)))
img_torch = img_torch.permute(2, 0, 1).unsqueeze(0)
img_torch = img_torch.type(torch.FloatTensor).cuda()
scores, _ = net(img_torch)
region_scores = scores[0, :, :, 0].cpu().data.numpy()
region_scores = np.uint8(np.clip(region_scores, 0, 1) * 255)
bgr_region_scores = cv2.resize(region_scores, (input.shape[1], input.shape[0]))
bgr_region_scores = cv2.cvtColor(bgr_region_scores, cv2.COLOR_GRAY2BGR)
4. 使用分水岭算法得到仿真的字符框
pursedo_bboxes = watershed(input, bgr_region_scores, False)
5. 根据仿真字符框和Word的字符数量计算confidence
confidence = self.get_confidence(real_char_nums, len(pursedo_bboxes))
6. 如果计算的confidence小于0.5,则按照单词长度对每个单元格设置仿真字符框,然后设置confidence=0.5。
if confidence <= 0.5:
width = input.shape[1]
height = input.shape[0]
width_per_char = width / len(word)
for i, char in enumerate(word):
if char == ' ':
continue
left = i * width_per_char
right = (i + 1) * width_per_char
bbox = np.array([[left, 0], [right, 0], [right, height],
[left, height]])
bboxes.append(bbox)
bboxes = np.array(bboxes, np.float32)
confidence = 0.5
7. 将文本框坐标恢复到透视变换之前的坐标系空间
for j in range(len(bboxes)):
ones = np.ones((4, 1))
tmp = np.concatenate([bboxes[j], ones], axis=-1)
I = np.matrix(MM).I
ori = np.matmul(I, tmp.transpose(1, 0)).transpose(1, 0)
bboxes[j] = ori[:, :2]
8. 剩下的生成region_score以及生成affinity_score的方式和SynthText的一致。
一些结论
CRAFT模型有很好的鲁棒性,这是因为他的识别目标是字符,本身就是比较小的存在
CRAFT之所以可以有比较好的效果,不仅仅受益于字符的空间信息,还受益于识别结果的语义分隔信息,也就是说,模型也可以根据字符的语义信息获得收益。














网友评论