Project goal: detect the orientation of multiple types of industrial drill bits on a conveyor belt, with sufficient response speed and a recall above 99%, i.e. missing as few bits as possible.
Approach: focus on a distinctive, strongly featured part of the drill bit. This shrinks the detection region, which reduces the network size and computation while improving detection accuracy and speed. The head of the bit is chosen as the focus: if the bit is oriented forward, the head is detected; otherwise the tail is detected.
Dataset acquisition: photograph different types of drill bits, crop out the head and tail regions, then augment the data.
1. The raw data are photos of 4 different drill bit models, 3 per model, in PNG format. The distinctive head region is cut out with Photoshop as class 1, and the tail region as class 2.
2. The head and tail cutouts are composited with background photos of the drill bits' working environment: each time a 256x256 region is taken from the background and the cutout is placed at a random position inside it. The cutout is then rotated by a set of angles and the compositing repeated, yielding nearly 20,000 images.
The data generation code is as follows:
import copy
import os
import random
import shutil

import cv2
import numpy as np

# Global counter used to name output files and to split them into train/valid/test.
num_saves = 0


def opencv_rotate(img, angle):
    """Rotate an image counterclockwise by `angle` degrees, expanding the
    canvas so nothing is clipped and filling the border transparently."""
    h, w = img.shape[:2]
    border_value = (0, 0, 0, 0)  # fully transparent border
    # Ensure an alpha channel exists: compositing reads the last channel as the blend mask.
    if img.shape[-1] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
    elif img.shape[-1] == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGRA)
    center = (w / 2, h / 2)
    scale = 1.0
    # cv2.getRotationMatrix2D returns the 2x3 affine matrix
    # [  cosA  sinA  (1-cosA)*cx - sinA*cy ]
    # [ -sinA  cosA  sinA*cx + (1-cosA)*cy ]
    M = cv2.getRotationMatrix2D(center, angle, scale)
    # New canvas size that fully contains the rotated image.
    new_h = int(w * np.fabs(np.sin(np.radians(angle))) + h * np.fabs(np.cos(np.radians(angle))))
    new_w = int(h * np.fabs(np.sin(np.radians(angle))) + w * np.fabs(np.cos(np.radians(angle))))
    # Shift so the rotated image is centered on the enlarged canvas.
    M[0, 2] += (new_w - w) / 2
    M[1, 2] += (new_h - h) / 2
    return cv2.warpAffine(img, M, (new_w, new_h), borderValue=border_value)


def merge_png_jpg(img_png, img_jpg, merge_factor, x_offset=0, y_offset=0):
    """Alpha-blend the foreground cutout onto a copy of the background at
    (x_offset, y_offset). `merge_factor` is the per-pixel alpha in [0, 1]."""
    height, width = img_png.shape[:2]
    img_jpg = copy.deepcopy(img_jpg)
    roi = img_jpg[y_offset:y_offset + height, x_offset:x_offset + width]
    for i in range(3):
        roi[:, :, i] = ((1 - merge_factor) * roi[:, :, i]
                        + merge_factor * img_png[:height, :width, i])
    return img_jpg


def translation_transformation(img_png, img_jpg, angle, label_class, save_path, png_name):
    """Slide the cutout across the background on a grid, jitter a 256x256 crop
    around each placement, and save the crop together with a YOLO-format label
    (class x_center y_center width height, normalized to the crop)."""
    global num_saves
    png_height, png_width = img_png.shape[:2]
    jpg_height, jpg_width = img_jpg.shape[:2]
    max_offset_y = jpg_height - 2 * 128
    max_offset_x = jpg_width - 2 * 128
    # The last channel of the RGBA cutout is the alpha mask.
    merge_factor = cv2.split(img_png)[-1] / 255.0
    for y_offset in range(2 * 128, max_offset_y, 100):
        for x_offset in range(2 * 128, max_offset_x, 100):
            img_new = merge_png_jpg(img_png, img_jpg, merge_factor, x_offset, y_offset)
            # Random jitter, bounded so the whole cutout stays inside the crop.
            x_rand = random.randint(-int(128 - png_width / 2), int(128 - png_width / 2))
            y_rand = random.randint(-int(128 - png_height / 2), int(128 - png_height / 2))
            y_top = int(y_offset + png_height / 2 - 128 + y_rand)
            y_bottom = int(y_offset + png_height / 2 + 128 + y_rand)
            x_left = int(x_offset + png_width / 2 - 128 + x_rand)
            x_right = int(x_offset + png_width / 2 + 128 + x_rand)
            img_new = img_new[y_top:y_bottom, x_left:x_right, :]
            # Round-robin split: 1/4 valid, 1/4 test, 1/2 train.
            save_dir = "train"
            if num_saves % 4 == 0:
                save_dir = "valid"
            elif num_saves % 4 == 2:
                save_dir = "test"
            img_name = str(num_saves) + '_' + str(angle) + \
                '_' + str(y_offset) + '_' + str(x_offset) + '_' + png_name
            cv2.imwrite(save_path + '/' + save_dir + "/images/" + img_name + ".jpg", img_new)
            # Bounding-box center and size, normalized to the saved 256x256 crop.
            gt_data = [label_class,
                       (x_offset + png_width / 2.0 - x_left) / 256.0,
                       (y_offset + png_height / 2.0 - y_top) / 256.0,
                       png_width / 256.0,
                       png_height / 256.0]
            with open(save_path + '/' + save_dir + "/labels/" + img_name + ".txt", 'w') as gt_file:
                gt_file.write(" ".join(str(i) for i in gt_data))
            num_saves += 1
            print(num_saves)


def data_augment_ps(png_path, jpg_path, save_path):
    """Composite one cutout onto one background over two angle ranges:
    angles around 0 degrees get class 0, angles around 180 degrees get class 1."""
    img_png = cv2.imread(png_path, cv2.IMREAD_UNCHANGED)
    img_jpg = cv2.imread(jpg_path, cv2.IMREAD_UNCHANGED)
    png_name = png_path.split('/')[-1].split('.')[0]
    for angle in range(-20, 10, 4):
        translation_transformation(opencv_rotate(img_png, angle), img_jpg, angle, 0, save_path, png_name)
    for angle in range(160, 190, 4):
        translation_transformation(opencv_rotate(img_png, angle), img_jpg, angle, 1, save_path, png_name)


def main():
    png_path = "/home/lz/DataDisk/datasets/jz_drill_recognize/for_max78000/foreground_img_tail"
    jpg_path = "/home/lz/DataDisk/datasets/jz_drill_recognize/for_max78000/background_img"
    save_path = "/home/lz/RamDisk/jz_drill_ps_crop_aug_data"
    # Start from a clean output tree.
    if os.path.exists(save_path):
        shutil.rmtree(save_path)
    for split in ("train", "valid", "test"):
        os.makedirs(save_path + "/" + split + "/images")
        os.makedirs(save_path + "/" + split + "/labels")
    global num_saves
    num_saves = 0
    for png_name in os.listdir(png_path):
        for jpg_name in os.listdir(jpg_path):
            data_augment_ps(os.path.join(png_path, png_name), os.path.join(jpg_path, jpg_name), save_path)


if __name__ == "__main__":
    main()
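To sanity-check the generated data, one can read a crop and its YOLO-format label back and draw the box on the image. This is a minimal sketch; the file name below is only a hypothetical example of the generator's output and should be replaced with a real one:

import cv2

# Hypothetical output pair from the generator above; adjust to a real file.
stem = "train/images/0_-20_256_256_drill"
img = cv2.imread(stem + ".jpg")
with open(stem.replace("images", "labels") + ".txt") as f:
    cls, cx, cy, w, h = (float(v) for v in f.read().split())

# Convert normalized YOLO coordinates back to pixel corners and draw the box.
H, W = img.shape[:2]
x1, y1 = int((cx - w / 2) * W), int((cy - h / 2) * H)
x2, y2 = int((cx + w / 2) * W), int((cy + h / 2) * H)
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
cv2.imwrite("label_check.jpg", img)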
Model training
Since this is a binary classification problem, the model from the official ai85net-cd.py is used here. That model was originally built to classify cats and dogs; it is chosen to get a quick initial read on feasibility. To minimize code changes, the dataset is arranged in the format that demo expects: images containing the drill head region go into the cats directory, and images containing the tail region go into the dogs directory, as in the sketch below.
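A minimal sketch of that arrangement, assuming the head and tail crops were generated into two separate directories; head_dir, tail_dir and data_root are hypothetical paths, and the exact layout expected by the demo's data loader should be verified:

import os
import shutil

# Hypothetical source directories of generated crops and the demo's data root.
head_dir = "jz_crops/head"   # drill-head crops -> "cats" class
tail_dir = "jz_crops/tail"   # drill-tail crops -> "dogs" class
data_root = "data/cats_vs_dogs/train"

for src_dir, class_name in ((head_dir, "cats"), (tail_dir, "dogs")):
    dst_dir = os.path.join(data_root, class_name)
    os.makedirs(dst_dir, exist_ok=True)
    for name in os.listdir(src_dir):
        shutil.copy(os.path.join(src_dir, name), os.path.join(dst_dir, name))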
Training: the official training script can be used as-is. After 100 epochs, Top1 accuracy reaches 100% and training is stopped manually:
./scripts/train_catsdogs.sh
2022-11-26 17:58:22,263 - Training epoch: 17972 samples (256 per mini-batch)
2022-11-26 17:58:23,402 - Epoch: [100][ 10/ 71] Overall Loss 0.000018 Objective Loss 0.000018 LR 0.000216 Time 0.113771
2022-11-26 17:58:24,192 - Epoch: [100][ 20/ 71] Overall Loss 0.000012 Objective Loss 0.000012 LR 0.000216 Time 0.096369
2022-11-26 17:58:24,961 - Epoch: [100][ 30/ 71] Overall Loss 0.000012 Objective Loss 0.000012 LR 0.000216 Time 0.089875
2022-11-26 17:58:25,741 - Epoch: [100][ 40/ 71] Overall Loss 0.000010 Objective Loss 0.000010 LR 0.000216 Time 0.086900
2022-11-26 17:58:26,529 - Epoch: [100][ 50/ 71] Overall Loss 0.000010 Objective Loss 0.000010 LR 0.000216 Time 0.085259
2022-11-26 17:58:27,311 - Epoch: [100][ 60/ 71] Overall Loss 0.000010 Objective Loss 0.000010 LR 0.000216 Time 0.084088
2022-11-26 17:58:28,054 - Epoch: [100][ 70/ 71] Overall Loss 0.000010 Objective Loss 0.000010 Top1 100.000000 LR 0.000216 Time 0.082687
2022-11-26 17:58:28,108 - Epoch: [100][ 71/ 71] Overall Loss 0.000010 Objective Loss 0.000010 Top1 100.000000 LR 0.000216 Time 0.082278
2022-11-26 17:58:28,156 - --- validate (epoch=100)-----------
2022-11-26 17:58:28,157 - 1996 samples (256 per mini-batch)
2022-11-26 17:58:28,946 - Epoch: [100][ 8/ 8] Loss 0.000024 Top1 100.000000
2022-11-26 17:58:28,994 - ==> Top1: 100.000 Loss: 0.000
Quantize the model to int8:
python quantize.py ../ai8x-training/logs/2022.11.26-174709/qat_best.pth.tar ../jz_cat_dog_synthesis/q8.pth.tar --device MAX78000
Evaluate the model's performance. The test set here consists of 624 photos of real drill bits:
python train.py --model ai85cdnet --dataset cats_vs_dogs --confusion --evaluate --exp-load-weights-from ../jz_cat_dog_synthesis/q8.pth.tar -8 --device MAX78000 "$@"
+----------------------+-------------+-----------+
| Key | Type | Value |
|----------------------+-------------+-----------|
| arch | str | ai85cdnet |
| compression_sched | dict | |
| epoch | int | 109 |
| extras | dict | |
| optimizer_state_dict | dict | |
| optimizer_type | type | Adam |
| state_dict | OrderedDict | |
+----------------------+-------------+-----------+
2022-11-26 22:37:39,101 - => Checkpoint['extras'] contents:
+-----------------+--------+---------------+
| Key | Type | Value |
|-----------------+--------+---------------|
| best_epoch | int | 109 |
| best_mAP | int | 0 |
| best_top1 | float | 100.0 |
| clipping_method | str | MAX_BIT_SHIFT |
| current_mAP | int | 0 |
| current_top1 | float | 100.0 |
+-----------------+--------+---------------+
2022-11-26 22:37:39,102 - Loaded compression schedule from checkpoint (epoch 109)
2022-11-26 22:37:39,104 - => loaded 'state_dict' from checkpoint '../jz_cat_dog_synthesis/q8.pth.tar'
2022-11-26 22:37:39,110 - Optimizer Type: <class 'torch.optim.sgd.SGD'>
2022-11-26 22:37:39,110 - Optimizer Args: {'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0.0001, 'nesterov': False}
2022-11-26 22:37:39,299 - Dataset sizes:
training=17972
validation=1996
test=624
2022-11-26 22:37:39,299 - --- test ---------------------
2022-11-26 22:37:39,300 - 624 samples (256 per mini-batch)
2022-11-26 22:37:40,133 - Test: [ 3/ 3] Loss 0.437118 Top1 87.500000
2022-11-26 22:37:40,189 - ==> Top1: 87.500 Loss: 0.437
2022-11-26 22:37:40,190 - ==> Confusion:
[[203 65]
[ 13 343]]
Results analysis
The initial results are fairly promising: 87.5% Top1 accuracy. Although this is still some distance from the accuracy required in production, there is plenty of room for improvement in the design; a YOLOv5 model can reach over 99% accuracy on this task. The next steps are to design a higher-performing network around the MAX78000's hardware characteristics and to further enrich the dataset.
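To see where the errors concentrate relative to the 99% recall goal, per-class recall and precision can be computed from the confusion matrix above. This assumes the usual convention of rows = ground truth and columns = predictions:

import numpy as np

# Confusion matrix from the evaluation log above
# (rows assumed to be ground truth, columns predictions).
cm = np.array([[203, 65],
               [13, 343]])

for c in range(2):
    recall = cm[c, c] / cm[c].sum()        # fraction of class c that was found
    precision = cm[c, c] / cm[:, c].sum()  # fraction of class-c predictions that are correct
    print(f"class {c}: recall={recall:.3f} precision={precision:.3f}")

# -> class 0: recall ~0.757, class 1: recall ~0.963; both are below the
#    99% recall target, consistent with the 87.5% overall Top1.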