I. Project Introduction
A gesture recognizer built on the MAX78000FTHR development board. It recognizes rock, scissors, paper, and "empty" (no gesture), shows the recognition result on an external screen, and also displays the live camera feed on the same screen.
II. Design Approach
The project consists of three parts: the recognition model, camera capture, and screen display.
The board's on-board camera captures images, the data is passed to the main chip for inference, and the processing result is shown on the external screen.
The overall workflow is as follows:
1. Study the official example code and follow the tutorials to learn model training, quantization, and deployment.
2. Collect the required material: gather rock/paper/scissors images from online datasets and by taking pictures myself.
3. Write the hardware drivers: drive the LCD over SPI, drive the camera to capture images, and display them on the LCD.
4. Write the data-loading script and related tools, and carry out image preprocessing, training, and quantization.
5. Complete model deployment and tune the code according to the actual behavior.
Overall block diagram
III. Collecting Training Material
1. Dataset images
For recognition training, more images generally means better results, but it is difficult for one person to capture a large number of pictures, so I looked for a ready-made dataset first. Searching online, I found the ASL Alphabet dataset for American Sign Language (https://www.kaggle.com/datasets/grassknoted/asl-alphabet). Its training set contains 87,000 images of 200x200 pixels in 29 classes: 26 for the letters A-Z and 3 for SPACE, DELETE, and NOTHING.
The B, nothing, S, and V classes in this dataset closely resemble the paper, empty, rock, and scissors gestures, so they can be used directly.
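Below is a minimal sketch of how those four classes could be copied out of the downloaded ASL Alphabet archive into the rps_big/train and rps_big/test layout expected by the data loader in the next section. The paths, the number of held-out test images, and the lower-cased folder names are my own illustration, not part of the original project files.

# Hypothetical reorganization script for the ASL Alphabet download (paths are assumptions).
import os
import shutil

SRC = "asl_alphabet_train/asl_alphabet_train"  # as extracted from the Kaggle archive
DST = "data/rps_big"
CLASSES = ["B", "nothing", "S", "V"]           # roughly paper, empty, rock, scissors
TEST_PER_CLASS = 1                             # hold back a few images for the test folder

for cls in CLASSES:
    files = sorted(os.listdir(os.path.join(SRC, cls)))
    splits = (("test", files[:TEST_PER_CLASS]), ("train", files[TEST_PER_CLASS:]))
    for split, names in splits:
        # Lower-case folder names so ImageFolder's sorted class order matches
        # the ('b', 'nothing', 's', 'v') output tuple in the loader below.
        out_dir = os.path.join(DST, split, cls.lower())
        os.makedirs(out_dir, exist_ok=True)
        for name in names:
            shutil.copy(os.path.join(SRC, cls, name), os.path.join(out_dir, name))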
2. Self-captured images
To make up for gaps in the training set and add more usable images, I captured some pictures myself. Maxim provides an image-capture example that streams camera frames to the PC; I used a screenshot tool to grab the frames I needed and added them to the dataset.
The images obtained in these two ways were entirely sufficient for this round of training.
IV. Pre-training Process
Write the data-loading script
The images are resized to 64x64.
###################################################################################################
#
# Copyright (C) 2018-2020 Maxim Integrated Products, Inc. All Rights Reserved.
#
# Maxim Integrated Products, Inc. Default Copyright Notice:
# https://www.maximintegrated.com/en/aboutus/legal/copyrights.html
#
###################################################################################################
"""
RPS Datasets
"""
import os
import sys
import torchvision
from torchvision import transforms
import ai8x
def rps_get_datasets(data, load_train=True, load_test=True):
    """
    rps dataset
    """
    (data_dir, args) = data

    path = data_dir
    dataset_path = os.path.join(path, "rps_big")
    is_dir = os.path.isdir(dataset_path)
    if not is_dir:
        print("******************************************")
        print("No data!!!")
        print("******************************************")
        sys.exit("Dataset not found..")

    training_data_path = os.path.join(data_dir, "rps_big")
    training_data_path = os.path.join(training_data_path, "train")
    test_data_path = os.path.join(data_dir, "rps_big")
    test_data_path = os.path.join(test_data_path, "test")

    # Loading and normalizing train dataset
    if load_train:
        train_transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ColorJitter(
                brightness=(0.3, .8),
                contrast=(.7, 1),
                saturation=0.2,
            ),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(([
                transforms.ColorJitter(),
            ]), p=0.3),
            transforms.ToTensor(),
            ai8x.normalize(args=args)
        ])

        train_dataset = torchvision.datasets.ImageFolder(root=training_data_path,
                                                         transform=train_transform)
    else:
        train_dataset = None

    # Loading and normalizing test dataset
    if load_test:
        test_transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            ai8x.normalize(args=args)
        ])

        test_dataset = torchvision.datasets.ImageFolder(root=test_data_path,
                                                        transform=test_transform)

        if args.truncate_testset:
            test_dataset.data = test_dataset.data[:1]
    else:
        test_dataset = None

    return train_dataset, test_dataset


datasets = [
    {
        'name': 'rps_big',
        'input': (3, 64, 64),
        'output': ('b', 'nothing', 's', 'v'),
        'weight': (1, 1, 1, 1),
        'loader': rps_get_datasets,
    },
]
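To sanity-check the loader outside of the full training script, it can be called directly with a minimal args object. The snippet below is only a sketch: it assumes the script above is importable (or run in the same file), that the data lives under ./data/rps_big, and that act_mode_8bit and truncate_testset are the only args fields the loader touches.

from types import SimpleNamespace

# Hypothetical quick check, run after the definitions above.
args = SimpleNamespace(act_mode_8bit=False,   # float range; ai8x.normalize maps inputs to [-0.5, 0.5]
                       truncate_testset=False)
train_ds, test_ds = rps_get_datasets(("./data", args))
print(len(train_ds), "training images,", len(test_ds), "test images")
print("classes:", train_ds.classes)           # expected: ['b', 'nothing', 's', 'v']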
Start of training
2022-11-26 07:37:24,861 - Optimizer Type: <class 'torch.optim.adam.Adam'>
2022-11-26 07:37:24,862 - Optimizer Args: {'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.0001, 'amsgrad': False}
2022-11-26 07:37:25,041 - Dataset sizes:
training=10800
validation=1200
test=4
2022-11-26 07:37:25,042 - Reading compression schedule from: schedule-asl.yaml
2022-11-26 07:37:25,046 -
2022-11-26 07:37:25,047 - Training epoch: 10800 samples (256 per mini-batch)
2022-11-26 07:37:50,724 - Epoch: [0][ 10/ 43] Overall Loss 1.368387 Objective Loss 1.368387 LR 0.001000 Time 2.567639
2022-11-26 07:38:13,264 - Epoch: [0][ 20/ 43] Overall Loss 1.313537 Objective Loss 1.313537 LR 0.001000 Time 2.410738
2022-11-26 07:38:38,605 - Epoch: [0][ 30/ 43] Overall Loss 1.245654 Objective Loss 1.245654 LR 0.001000 Time 2.451815
2022-11-26 07:39:05,066 - Epoch: [0][ 40/ 43] Overall Loss 1.188697 Objective Loss 1.188697 LR 0.001000 Time 2.500355
2022-11-26 07:39:10,427 - Epoch: [0][ 43/ 43] Overall Loss 1.173808 Objective Loss 1.173808 Top1 63.157895 LR 0.001000 Time 2.450569
2022-11-26 07:39:10,739 - --- validate (epoch=0)-----------
2022-11-26 07:39:10,740 - 1200 samples (256 per mini-batch)
2022-11-26 07:39:18,305 - Epoch: [0][ 5/ 5] Loss 0.968138 Top1 60.000000
2022-11-26 07:39:18,584 - ==> Top1: 60.000 Loss: 0.968
2022-11-26 07:39:18,592 - ==> Confusion:
[[211 12 35 24]
[ 57 160 14 65]
[ 54 172 29 44]
[ 1 2 0 320]]
2022-11-26 07:39:18,604 - ==> Best [Top1: 60.000 Sparsity:0.00 Params: 60080 on epoch: 0]
Final training epoch and validation
2022-11-26 10:34:42,820 - Saving checkpoint to: logs/2022.11.26-073724/qat_checkpoint.pth.tar
2022-11-26 10:34:42,832 -
2022-11-26 10:34:42,836 - Training epoch: 10800 samples (256 per mini-batch)
2022-11-26 10:35:09,183 - Epoch: [99][ 10/ 43] Overall Loss 0.345190 Objective Loss 0.345190 LR 0.001000 Time 2.634476
2022-11-26 10:35:33,291 - Epoch: [99][ 20/ 43] Overall Loss 0.345397 Objective Loss 0.345397 LR 0.001000 Time 2.522606
2022-11-26 10:35:57,030 - Epoch: [99][ 30/ 43] Overall Loss 0.345680 Objective Loss 0.345680 LR 0.001000 Time 2.473005
2022-11-26 10:36:19,786 - Epoch: [99][ 40/ 43] Overall Loss 0.345482 Objective Loss 0.345482 LR 0.001000 Time 2.423636
2022-11-26 10:36:24,455 - Epoch: [99][ 43/ 43] Overall Loss 0.345504 Objective Loss 0.345504 Top1 100.000000 LR 0.001000 Time 2.363106
2022-11-26 10:36:24,660 - --- validate (epoch=99)-----------
2022-11-26 10:36:24,663 - 1200 samples (256 per mini-batch)
2022-11-26 10:36:31,351 - Epoch: [99][ 5/ 5] Loss 0.343993 Top1 100.000000
2022-11-26 10:36:31,619 - ==> Top1: 100.000 Loss: 0.344
2022-11-26 10:36:31,625 - ==> Confusion:
[[282 0 0 0]
[ 0 296 0 0]
[ 0 0 299 0]
[ 0 0 0 323]]
2022-11-26 10:36:31,637 - ==> Best [Top1: 100.000 Sparsity:0.00 Params: 60080 on epoch: 99]
2022-11-26 10:36:31,638 - Saving checkpoint to: logs/2022.11.26-073724/qat_checkpoint.pth.tar
2022-11-26 10:36:31,654 - --- test ---------------------
2022-11-26 10:36:31,655 - 4 samples (256 per mini-batch)
2022-11-26 10:36:32,722 - Test: [ 1/ 1] Loss 0.343015 Top1 100.000000
2022-11-26 10:36:32,938 - ==> Top1: 100.000 Loss: 0.343
2022-11-26 10:36:32,940 - ==> Confusion:
[[1 0 0 0]
[0 1 0 0]
[0 0 1 0]
[0 0 0 1]]
Admittedly, a 100% recognition rate is a stretch; it is probably because the images are simple and there are only a few classes (and the held-out test folder here contains only 4 images).
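For reference, the Top1 figures in the log can be reproduced directly from the validation confusion matrices (correct predictions sit on the diagonal). A small check using the epoch-0 and epoch-99 matrices copied from above:

import numpy as np

epoch0 = np.array([[211, 12, 35, 24],
                   [57, 160, 14, 65],
                   [54, 172, 29, 44],
                   [1, 2, 0, 320]])
epoch99 = np.array([[282, 0, 0, 0],
                    [0, 296, 0, 0],
                    [0, 0, 299, 0],
                    [0, 0, 0, 323]])

for name, cm in (("epoch 0", epoch0), ("epoch 99", epoch99)):
    top1 = cm.trace() / cm.sum() * 100    # diagonal / total
    print(f"{name}: Top1 = {top1:.1f}%")  # 60.0% and 100.0%, matching the log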
V. Results
Recognizing "empty"
Recognizing "paper"
Recognizing "rock"
Recognizing "scissors"
VI. Main Code
Image display
void lcd_show_sampledata(uint32_t* data0, uint32_t* data1, uint32_t* data2, int xcord, int ycord,
                         int length)
{
    int i;
    int j;
    int x;
    int y;
    int r;
    int g;
    int b;
    int scale = 1.2; // note: scale is an int, so the fractional part is truncated and scale == 1
    uint32_t color;
    uint8_t* ptr0;
    uint8_t* ptr1;
    uint8_t* ptr2;

    x = 0;
    y = 0;

    for (i = 0; i < length; i++)
    {
        // Each 32-bit word packs four 8-bit samples of one color channel
        ptr0 = (uint8_t*)&data0[i];
        ptr1 = (uint8_t*)&data1[i];
        ptr2 = (uint8_t*)&data2[i];

        for (j = 0; j < 4; j++)
        {
            r = ptr0[j];
            g = ptr1[j];
            b = ptr2[j];
            color = RGB(r, g, b); // convert to RGB565
            MXC_TFT_WritePixel(xcord * scale + 2 * x * scale, ycord * scale + 2 * y * scale,
                               scale, scale, color);
            x += 1;
            if (x >= (IMAGE_SIZE_X))
            {
                x = 0;
                y += 1;
                if ((y + 6) >= (IMAGE_SIZE_Y))
                    return;
            }
        }
    }
}
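The RGB() macro used above packs the 8-bit R/G/B samples into a 16-bit RGB565 value for the TFT. The actual macro is defined in the Maxim TFT driver; the short Python illustration below only shows the standard RGB565 packing and is not copied from that driver.

def rgb888_to_rgb565(r, g, b):
    # 5 bits red, 6 bits green, 5 bits blue
    return ((r & 0xF8) << 8) | ((g & 0xFC) << 3) | (b >> 3)

print(hex(rgb888_to_rgb565(255, 0, 0)))  # 0xf800 - pure red
print(hex(rgb888_to_rgb565(0, 255, 0)))  # 0x7e0  - pure green
print(hex(rgb888_to_rgb565(0, 0, 255)))  # 0x1f   - pure blue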
Image capture and CNN inference
// Capture a single camera frame.
printf("\nCapture a camera frame %d\n", ++frame);
capture_camera_img();

// Copy the image data to the CNN input arrays.
printf("Copy camera frame to CNN input buffers.\n");
process_camera_img(input_0_camera, input_1_camera, input_2_camera);
convert_img_unsigned_to_signed(input_0_camera, input_1_camera, input_2_camera);

cnn_init();         // Bring state machine into consistent state
cnn_load_weights(); // Load kernels
cnn_load_bias();
cnn_configure();    // Configure state machine
cnn_start();        // Start CNN processing
load_input();       // Load data input via FIFO

MXC_TMR_SW_Start(MXC_TMR0);
while (cnn_time == 0)
    __WFI(); // Wait for CNN

softmax_layer();

printf("Time for CNN: %d us\n\n", cnn_time);
printf("Classification results:\n");
for (i = 0; i < CNN_NUM_OUTPUTS; i++)
{
    digs = (1000 * ml_softmax[i] + 0x4000) >> 15; // Q15 softmax -> tenths of a percent, rounded
    tens = digs % 10;
    digs = digs / 10;
    result[i] = digs; // integer percent, used below for the threshold checks
    printf("[%7d] -> Class %d %8s: %d.%d%%\r\n", ml_data[i], i, classes[i], digs, tens);
}
printf("\n");
Deciding and displaying the recognition result
if (result[0] > 0) // threshold lowered to match actual behavior: any nonzero score counts as paper
{
    // Paper
    TFT_Print(buff, 30, 55, font_2, sprintf(buff, "Paper "));
    printf("User choose: %s \r\n", classes[0]);
}
else if (result[1] > 60)
{
    // Rock
    TFT_Print(buff, 30, 55, font_2, sprintf(buff, "Rock "));
    printf("User choose: %s \r\n", classes[1]);
}
else if (result[2] > 60)
{
    // Scissors
    TFT_Print(buff, 30, 55, font_2, sprintf(buff, "Scissors"));
    printf("User choose: %s \r\n", classes[2]);
}
else if (result[3] > 60)
{
    // Empty
    TFT_Print(buff, 30, 55, font_2, sprintf(buff, "Empty "));
    printf("User choose: %s \r\n", classes[3]);
}
else
{
    TFT_Print(buff, 30, 55, font_2, sprintf(buff, "Unknown "));
}

TFT_Print(buff, 205, 55, font_2, sprintf(buff, "Paper:%d ", result[0]));
TFT_Print(buff, 205, 75, font_2, sprintf(buff, "Rock:%d ", result[1]));
TFT_Print(buff, 205, 95, font_2, sprintf(buff, "Scissors:%d ", result[2]));
TFT_Print(buff, 205, 115, font_2, sprintf(buff, "Empty:%d ", result[3]));
七、问题与下一步计划
一开始本打算做完整手语识别的,但是实际部署到板卡上发现识别情况与期望相差非常大,参考官方提供的例程,发现都是一些比较简单的识别,并没有太多种类识别部署到板卡上,于是就转变思路仅仅选择石头剪刀布与空这几个简单的手势去做,但也出现了布难以识别(识别阈值被我调节到大于0就判断为布),非这三种手势也会给出输出的问题(应该是空数据集数据种类太少的原因,应该把一些其他手势放到空里面)。
下一步计划继续进行AI相关的学习,试着将更多的手势去部署到板卡上测试,通过修改训练参数让代码的适应能力更好,将更多的图片换成自己采集的图片,这样的话估计能大幅度提高实际用摄像头采集识别的准确率。